Commit 58e788f

Feature kv cache (#1045)
* created separate cache functions
* moved components and decoupled config
* refactored references
* updated TransformerLens cache naming
* replaced references
* ran format
* added full cache compatibility
* fixed type issues
* removed extra param setting
1 parent 2106d89 commit 58e788f

File tree

16 files changed: +1177 -162 lines

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 31 additions & 0 deletions
@@ -67,6 +67,37 @@ def test_text_generation():
     assert len(output) > len(prompt), "Generated text should be longer than the prompt"
 
 
+def test_generate_with_kv_cache():
+    """Test that generate works with use_past_kv_cache parameter."""
+    model_name = "gpt2"  # Use a smaller model for testing
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    if bridge.tokenizer.pad_token is None:
+        bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
+
+    prompt = "The quick brown fox jumps over the lazy dog"
+
+    # Test with KV cache enabled
+    output_with_cache = bridge.generate(prompt, max_new_tokens=5, use_past_kv_cache=True)
+
+    # Test with KV cache disabled
+    output_without_cache = bridge.generate(prompt, max_new_tokens=5, use_past_kv_cache=False)
+
+    # Both should produce valid outputs
+    assert isinstance(output_with_cache, str), "Output with KV cache should be a string"
+    assert isinstance(output_without_cache, str), "Output without KV cache should be a string"
+    assert len(output_with_cache) > len(
+        prompt
+    ), "Generated text with KV cache should be longer than the prompt"
+    assert len(output_without_cache) > len(
+        prompt
+    ), "Generated text without KV cache should be longer than the prompt"
+
+    # The outputs might be different due to sampling, but both should be valid
+    assert len(output_with_cache) > 0, "Output with KV cache should not be empty"
+    assert len(output_without_cache) > 0, "Output without KV cache should not be empty"
+
+
 def test_hooks():
     """Test that hooks can be added and removed correctly."""
     model_name = "gpt2"  # Use a smaller model for testing
