
Commit 41d1717

New DynamicSlidingWindowLayer & associated Cache (#40039)
* start adding the layer
* style
* improve
* modular
* fix
* fix
* improve
* generate integration
* comment
* remove old one
* remove
* fix
* fix
* fix
* fix all recompiles
* fix
* doc
* fix
* add text config check
* fix encoderdecoder cache
* add it for all models with sliding/hybrid support
* revert
* start fixing
* prophetnet
* fsmt
* fix ddp_data
* add test for mistral
* improve mistral test and add gemma2 test
* docstrings
1 parent ab455e0 commit 41d1717


61 files changed (+508 / -320 lines)


src/transformers/cache_utils.py

Lines changed: 331 additions & 254 deletions
Large diffs are not rendered by default.
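The bulk of the commit lives in this file: it adds the new DynamicSlidingWindowLayer and teaches DynamicCache to build a mix of full and sliding-window layers when it is given a model config. Since the diff is not rendered, the snippet below is only a minimal, illustrative sketch of the sliding-window idea (class and attribute names are made up, not the actual cache_utils.py API): the layer grows like a normal dynamic layer, but never keeps more than `sliding_window` past key/value states.

    import torch

    # Illustrative sketch only - not the implementation added in cache_utils.py.
    class ToySlidingWindowLayer:
        def __init__(self, sliding_window: int):
            self.sliding_window = sliding_window
            self.keys = None
            self.values = None

        def update(self, key_states: torch.Tensor, value_states: torch.Tensor):
            # Append the new states along the sequence dimension...
            if self.keys is None:
                self.keys, self.values = key_states, value_states
            else:
                self.keys = torch.cat([self.keys, key_states], dim=-2)
                self.values = torch.cat([self.values, value_states], dim=-2)
            # ...then keep only the last `sliding_window` positions, so the memory used by
            # this layer stays bounded no matter how long generation runs.
            self.keys = self.keys[..., -self.sliding_window :, :]
            self.values = self.values[..., -self.sliding_window :, :]
            return self.keys, self.values

A full-attention layer keeps everything it has seen instead, which is why hybrid models (some layers sliding, some full) need the cache to know the model config.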

src/transformers/generation/utils.py

Lines changed: 14 additions & 4 deletions
@@ -1813,7 +1813,7 @@ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
         # Support for BC tuple cache format
         if isinstance(cache, tuple):
             past_length = cache[0][0].shape[2]
-        elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
+        elif hasattr(cache, "get_seq_length"):
             past_length = cache.get_seq_length()
 
         cache_position = cache_position[past_length:]
@@ -1959,6 +1959,16 @@ def _prepare_cache_for_generation(
         generation_config.cache_implementation = generation_config.cache_implementation or getattr(
             self.config.get_text_config(decoder=True), "cache_implementation", None
         )
+
+        # assisted decoding and contrastive search need to roll-back the Cache, which is not supported if
+        # it has sliding layers - so if we use any of those 2, do not pass the config to DynamicCache, which
+        # will result in creating a Cache with only full layers even if model uses sliding window
+        generation_mode = generation_config.get_generation_mode(assistant_model)
+        dynamic_cache_kwargs = (
+            {"config": self.config}
+            if generation_mode not in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH)
+            else {}
+        )
         if generation_config.cache_implementation is not None:
             if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                 if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
@@ -2001,15 +2011,15 @@ def _prepare_cache_for_generation(
             elif generation_config.cache_implementation == "offloaded":
                 model_kwargs[cache_name] = OffloadedCache()
             elif generation_config.cache_implementation == "dynamic":
-                model_kwargs[cache_name] = DynamicCache()
+                model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
 
         # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
         # keeps copying the cache thus using much more memory
         else:
             model_kwargs[cache_name] = (
-                DynamicCache()
+                DynamicCache(**dynamic_cache_kwargs)
                 if not requires_cross_attention_cache
-                else EncoderDecoderCache(DynamicCache(), DynamicCache())
+                else EncoderDecoderCache(DynamicCache(**dynamic_cache_kwargs), DynamicCache(**dynamic_cache_kwargs))
             )
 
     def _supports_logits_to_keep(self) -> bool:
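The practical effect of the hunks above: generate() now decides how to build the dynamic cache from the decoding strategy. Assisted generation and contrastive search get a plain full-layer DynamicCache (so the cache can be rolled back), while every other mode passes the model config so sliding-window layers are used where the architecture has them. A rough usage sketch, where the checkpoint name is just a placeholder for any model with sliding-window attention:

    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    # Placeholder checkpoint - any sliding-window or hybrid model behaves the same way.
    model_id = "mistralai/Mistral-7B-v0.1"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("Sliding-window caches", return_tensors="pt")

    # Regular decoding: generate() builds DynamicCache(config=...) under the hood,
    # so sliding-window layers stay bounded instead of growing with the sequence.
    output = model.generate(**inputs, max_new_tokens=32)

    # The same config-aware cache can also be created by hand and passed to the model directly:
    cache = DynamicCache(config=model.config)
    logits = model(**inputs, past_key_values=cache, use_cache=True).logits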

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 1 addition & 1 deletion
@@ -364,7 +364,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 1 deletion
@@ -744,7 +744,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 1 deletion
@@ -396,7 +396,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 1 addition & 1 deletion
@@ -371,7 +371,7 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None and not self.training:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 1 addition & 1 deletion
@@ -405,7 +405,7 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None and not self.training:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/csm/modeling_csm.py

Lines changed: 1 addition & 1 deletion
@@ -702,7 +702,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -512,7 +512,7 @@ def forward(
         inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
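The later model diffs in this commit (hidden in this truncated listing) largely repeat the same one-line change shown above: the cache a model builds for itself in forward now receives the model config. Conceptually, that is what lets the cache choose a layer type per decoder layer. A toy sketch of that per-layer decision, using hypothetical sliding_window and layer_types attributes to stand in for whatever the real config exposes:

    # Toy sketch - the attribute names are hypothetical, not necessarily what DynamicCache reads.
    def pick_layer_kinds(config) -> list:
        """Return "sliding" or "full" for each decoder layer."""
        sliding_window = getattr(config, "sliding_window", None)
        layer_types = getattr(config, "layer_types", None)
        if layer_types is not None:
            # Hybrid models: alternate kinds per layer, as declared by the config.
            return ["sliding" if t == "sliding_attention" else "full" for t in layer_types]
        # Uniform models: every layer is sliding if a window is set, full otherwise.
        kind = "sliding" if sliding_window is not None else "full"
        return [kind] * config.num_hidden_layers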
