
Commit dc11a3c

[core] Refactor the Cache logic to make it simpler and more general (#39797)
* Simplify the logic quite a bit
* Update cache_utils.py
* continue work
* continue simplifying a lot
* style
* Update cache_utils.py
* offloading much simpler
* style
* Update cache_utils.py
* update inits
* Update cache_utils.py
* consistency
* Update cache_utils.py
* update generate
* style
* fix
* fix
* add early_initialization
* fix
* fix mamba caches
* update
* fix
* fix
* fix
* fix tests
* fix configs
* revert
* fix tests
* alright
* Update modeling_gptj.py
* fix the constructors
* cache tests
* Update test_cache_utils.py
* fix
* simplify
* back to before -> avoid compile bug
* doc
* mistral test
* llama4 test dtype
* Update test_modeling_llama4.py
* CIs
* Finally find a nice impl
* Update cache_utils.py
* Update cache_utils.py
* add lazy methods in autodoc
* typo
* better doc
* Add detailed docstring for lazy init
* CIs
* style
* fix
1 parent 95510ab commit dc11a3c

48 files changed: +779 -1449 lines changed

docs/source/en/internal/generation_utils.md

Lines changed: 11 additions & 18 deletions
@@ -363,37 +363,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens
     - get_max_cache_shape
     - reset
     - reorder_cache
+    - lazy_initialization
 
 [[autodoc]] DynamicLayer
     - update
+    - lazy_initialization
     - crop
     - batch_repeat_interleave
     - batch_select_indices
 
 [[autodoc]] StaticLayer
     - update
+    - lazy_initialization
 
 [[autodoc]] SlidingWindowLayer
     - update
+    - lazy_initialization
 
-[[autodoc]] CacheProcessor
-    - pre_update
-    - post_update
-
-[[autodoc]] OffloadedCacheProcessor
-    - pre_update
-
-[[autodoc]] QuantizedCacheProcessor
-    - post_update
-
-[[autodoc]] QuantoQuantizedCacheProcessor
-    - post_update
+[[autodoc]] QuantoQuantizedLayer
+    - update
+    - lazy_initialization
 
-[[autodoc]] HQQQuantizedCacheProcessor
-    - post_update
+[[autodoc]] HQQQuantizedLayer
+    - update
+    - lazy_initialization
 
 [[autodoc]] Cache
     - update
+    - early_initialization
     - get_seq_length
     - get_mask_sizes
     - get_max_cache_shape

@@ -411,12 +408,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] QuantoQuantizedCache
 
-[[autodoc]] QuantoQuantizedCacheProcessor
-
 [[autodoc]] HQQQuantizedCache
 
-[[autodoc]] HQQQuantizedCacheProcessor
-
 [[autodoc]] OffloadedCache
 
 [[autodoc]] StaticCache
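The reworked autodoc list above reflects the API change in this commit: the `CacheProcessor` family is removed, quantized caching moves into `QuantoQuantizedLayer` / `HQQQuantizedLayer`, each layer documents a `lazy_initialization` method, and `Cache` gains `early_initialization`. As a rough, hedged illustration of the layer-level flow (a sketch, not code from this commit; shapes and values are made up), the snippet below exercises `DynamicCache.update`, whose storage is presumably set up lazily from the first key/value tensors it sees:

```python
# Sketch only: exercises the public Cache API touched by this commit.
# Assumes a transformers build that includes the refactor; shapes are illustrative.
import torch
from transformers import DynamicCache

cache = DynamicCache()

# Fake key/value states for one attention layer: (batch, num_heads, seq_len, head_dim)
key_states = torch.randn(1, 8, 5, 64)
value_states = torch.randn(1, 8, 5, 64)

# `update` stores the new states for the given layer and returns the full keys/values
# to attend over; with a fresh cache this is just the tensors we passed in.
keys, values = cache.update(key_states, value_states, layer_idx=0)

print(keys.shape)              # torch.Size([1, 8, 5, 64])
print(cache.get_seq_length())  # 5
```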

docs/source/en/kv_cache.md

Lines changed: 1 addition & 1 deletion
@@ -312,7 +312,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Init StaticCache with big enough max-length (1024 tokens for the below example)
 # You can also init a DynamicCache, if that suits you better
-prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device.type, dtype=torch.bfloat16)
+prompt_cache = StaticCache(config=model.config, max_cache_len=1024)
 
 INITIAL_PROMPT = "You are a helpful assistant. "
 inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type)
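This change is the user-visible core of the refactor: `StaticCache` now only needs the model config and a maximum cache length, with batch size, device and dtype presumably resolved lazily at first use. Below is a hedged sketch of the surrounding kv_cache.md prompt-reuse example rewritten around the new constructor (the model id is illustrative, any causal LM works):

```python
# Sketch under the assumptions above; not copied verbatim from the docs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative choice
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# New-style constructor: no max_batch_size / device / dtype arguments
prompt_cache = StaticCache(config=model.config, max_cache_len=1024)

# Pre-fill the cache with a shared prompt so later requests can reuse it
inputs = tokenizer("You are a helpful assistant. ", return_tensors="pt").to(model.device)
with torch.no_grad():
    prompt_cache = model(**inputs, past_key_values=prompt_cache).past_key_values
```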

docs/source/en/llm_optims.md

Lines changed: 1 addition & 4 deletions
@@ -93,11 +93,8 @@ model.generation_config.max_new_tokens = 16
 
 past_key_values = StaticCache(
     config=model.config,
-    max_batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
-    device=model.device,
-    dtype=model.dtype
 )
 outputs = model.generate(**input_ids, past_key_values=past_key_values)
 print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@@ -159,7 +156,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, max_cache_len=4096
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(

docs/source/en/model_doc/gemma2.md

Lines changed: 1 addition & 2 deletions
@@ -138,8 +138,7 @@ visualizer("You are an assistant. Make sure you print me")
 
 inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
 max_generated_length = inputs.input_ids.shape[1] + 10
-past_key_values = HybridCache(config=model.config, max_batch_size=1,
-                              max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
 outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
 ```

docs/source/ko/internal/generation_utils.md

Lines changed: 4 additions & 18 deletions
@@ -362,21 +362,11 @@ generation_output[:2]
 [[autodoc]] SlidingWindowLayer
     - update
 
-[[autodoc]] CacheProcessor
-    - pre_update
-    - post_update
-
-[[autodoc]] OffloadedCacheProcessor
-    - pre_update
-
-[[autodoc]] QuantizedCacheProcessor
-    - post_update
-
-[[autodoc]] QuantoQuantizedCacheProcessor
-    - post_update
+[[autodoc]] QuantoQuantizedLayer
+    - update
 
-[[autodoc]] HQQQuantizedCacheProcessor
-    - post_update
+[[autodoc]] HQQQuantizedLayer
+    - update
 
 [[autodoc]] Cache
     - update

@@ -397,12 +387,8 @@ generation_output[:2]
 
 [[autodoc]] QuantoQuantizedCache
 
-[[autodoc]] QuantoQuantizedCacheProcessor
-
 [[autodoc]] HQQQuantizedCache
 
-[[autodoc]] HQQQuantizedCacheProcessor
-
 [[autodoc]] OffloadedCache
 
 [[autodoc]] StaticCache

docs/source/ko/llm_optims.md

Lines changed: 1 addition & 4 deletions
@@ -99,11 +99,8 @@ model.generation_config.max_new_tokens = 16
 
 past_key_values = StaticCache(
     config=model.config,
-    max_batch_size=1,
     # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
-    device=model.device,
-    dtype=model.dtype
 )
 outputs = model.generate(**input_ids, past_key_values=past_key_values)
 print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@@ -161,7 +158,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, max_cache_len=4096
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(

src/transformers/__init__.py

Lines changed: 8 additions & 7 deletions
@@ -377,23 +377,18 @@
         "StaticLayer",
         "SlidingWindowLayer",
         "ChunkedSlidingLayer",
-        "CacheProcessor",
-        "OffloadedCacheProcessor",
-        "QuantizedCacheProcessor",
-        "QuantoQuantizedCacheProcessor",
-        "HQQQuantizedCacheProcessor",
+        "QuantoQuantizedLayer",
+        "HQQQuantizedLayer",
         "Cache",
         "CacheConfig",
         "DynamicCache",
         "EncoderDecoderCache",
         "HQQQuantizedCache",
-        "HQQQuantizedCacheProcessor",
         "HybridCache",
         "HybridChunkedCache",
         "OffloadedCache",
         "OffloadedStaticCache",
         "QuantizedCache",
-        "QuantoQuantizedCacheProcessor",
         "QuantizedCacheConfig",
         "QuantoQuantizedCache",
         "SinkCache",

@@ -586,19 +581,25 @@
     # All modeling imports
     from .cache_utils import Cache as Cache
     from .cache_utils import CacheConfig as CacheConfig
+    from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer
     from .cache_utils import DynamicCache as DynamicCache
+    from .cache_utils import DynamicLayer as DynamicLayer
     from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
     from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
+    from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
     from .cache_utils import HybridCache as HybridCache
     from .cache_utils import MambaCache as MambaCache
     from .cache_utils import OffloadedCache as OffloadedCache
     from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
     from .cache_utils import QuantizedCache as QuantizedCache
     from .cache_utils import QuantizedCacheConfig as QuantizedCacheConfig
     from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
+    from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
     from .cache_utils import SinkCache as SinkCache
     from .cache_utils import SlidingWindowCache as SlidingWindowCache
+    from .cache_utils import SlidingWindowLayer as SlidingWindowLayer
     from .cache_utils import StaticCache as StaticCache
+    from .cache_utils import StaticLayer as StaticLayer
     from .configuration_utils import PretrainedConfig as PretrainedConfig
     from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
     from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
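Since the refactor exports the layer classes from the top-level package (and drops the `*CacheProcessor` names), a quick import smoke test might look like the sketch below. It assumes a transformers build containing this commit; actually instantiating the quantized layers would additionally need the quanto / hqq backends installed.

```python
# Sketch: verify the new public names resolve through the lazy top-level module.
from transformers import (
    ChunkedSlidingLayer,
    DynamicLayer,
    HQQQuantizedLayer,
    QuantoQuantizedLayer,
    SlidingWindowLayer,
    StaticLayer,
)

print(DynamicLayer, StaticLayer)  # plain classes; importing them needs no extra dependencies
```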
