Commit 242bb2c

One cache class to rule them all (#40276)
* remove all classes
* fix generate
* start replacing everywhere
* finish removing everywhere
* typo
* typo
* fix
* typo
* remove num_layers=1
* CI
* fix all docstrings
* review
* style
1 parent 1054494 commit 242bb2c

File tree: 16 files changed, +300 additions, -577 deletions


src/transformers/cache_utils.py

Lines changed: 189 additions & 285 deletions
Large diffs are not rendered by default.

src/transformers/generation/configuration_utils.py

Lines changed: 15 additions & 28 deletions
@@ -43,35 +43,23 @@
 
 logger = logging.get_logger(__name__)
 METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
-STATIC_CACHE_CLASSES_MAPPING = {}
-QUANT_BACKEND_CLASSES_MAPPING = {}
-ALL_CACHE_IMPLEMENTATIONS = []
+STATIC_CACHE_IMPLEMENTATIONS = ("static", "offloaded_static")
+DYNAMIC_CACHE_IMPLEMENTATIONS = ("dynamic", "offloaded", "quantized")
+# All the following are redundant and deprecated, but kept for BC
+DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS = (
+    "sliding_window",
+    "hybrid",
+    "hybrid_chunked",
+    "offloaded_hybrid",
+    "offloaded_hybrid_chunked",
+)
+ALL_STATIC_CACHE_IMPLEMENTATIONS = STATIC_CACHE_IMPLEMENTATIONS + DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS
+ALL_CACHE_IMPLEMENTATIONS = ALL_STATIC_CACHE_IMPLEMENTATIONS + DYNAMIC_CACHE_IMPLEMENTATIONS
+
 
 if is_torch_available():
-    from ..cache_utils import (
-        HQQQuantizedCache,
-        HybridCache,
-        HybridChunkedCache,
-        OffloadedHybridCache,
-        OffloadedStaticCache,
-        QuantoQuantizedCache,
-        SlidingWindowCache,
-        StaticCache,
-    )
     from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
 
-    STATIC_CACHE_CLASSES_MAPPING = {
-        "static": StaticCache,
-        "offloaded_static": OffloadedStaticCache,
-        "sliding_window": SlidingWindowCache,
-        "hybrid": HybridCache,
-        "hybrid_chunked": HybridChunkedCache,
-        "offloaded_hybrid": OffloadedHybridCache,
-        "offloaded_hybrid_chunked": OffloadedHybridCache,
-    }
-    QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
-    ALL_CACHE_IMPLEMENTATIONS = list(STATIC_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]
-
 
 class GenerationMode(ExplicitEnum):
     """

@@ -173,9 +161,8 @@ class GenerationConfig(PushToHubMixin):
 
             - `"dynamic"`: [`DynamicCache`]
             - `"static"`: [`StaticCache`]
-            - `"offloaded_static"`: [`OffloadedStaticCache`]
-            - `"sliding_window"`: [`SlidingWindowCache`]
-            - `"hybrid"`: [`HybridCache`]
+            - `"offloaded"`: [`DynamicCache(offloaded=True)`]
+            - `"offloaded_static"`: [`StaticCache(offloaded=True)`]
             - `"quantized"`: [`QuantizedCache`]
 
             If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
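To illustrate the slimmed-down option set, here is a minimal usage sketch (not part of the diff). The checkpoint name and token count are placeholders; it assumes a causal LM whose config supports the static cache.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder checkpoint, any static-cache-capable causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The KV cache is", return_tensors="pt")

# "static" now also covers the former sliding-window/hybrid variants; the
# per-layer structure is inferred from the model config. Deprecated names
# such as "hybrid" still work but emit a deprecation warning.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```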

src/transformers/generation/utils.py

Lines changed: 18 additions & 24 deletions
@@ -32,9 +32,8 @@
     Cache,
     DynamicCache,
     EncoderDecoderCache,
-    HybridChunkedCache,
-    OffloadedCache,
-    OffloadedHybridCache,
+    QuantizedCache,
+    StaticCache,
 )
 from ..configuration_utils import PretrainedConfig
 from ..dynamic_module_utils import (

@@ -71,8 +70,9 @@
     _prepare_token_type_ids,
 )
 from .configuration_utils import (
-    QUANT_BACKEND_CLASSES_MAPPING,
-    STATIC_CACHE_CLASSES_MAPPING,
+    ALL_STATIC_CACHE_IMPLEMENTATIONS,
+    DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS,
+    STATIC_CACHE_IMPLEMENTATIONS,
     GenerationConfig,
     GenerationMode,
 )

@@ -1822,27 +1822,18 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
 
         Returns the resulting cache object.
         """
-        if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
-            cache_implementation = "hybrid_chunked"
-
-        cache_cls: Cache = STATIC_CACHE_CLASSES_MAPPING[cache_implementation]
         requires_cross_attention_cache = (
             self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
         )
+        offload_cache = "offloaded" in cache_implementation
 
         if hasattr(self, "_cache"):
             cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
 
-            if cache_implementation == "sliding_window":
-                max_cache_len = min(self.config.sliding_window, max_cache_len)
-
         need_new_cache = (
             not hasattr(self, "_cache")
-            or (not isinstance(cache_to_check, cache_cls))
+            or cache_to_check.offloading != offload_cache
             or cache_to_check.max_batch_size != batch_size
-            or isinstance(
-                cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
-            )  # due to internal slicing, we always re-init
             or cache_to_check.max_cache_len < max_cache_len
         )
 
@@ -1853,12 +1844,12 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
             )
 
         if need_new_cache:
-            cache_kwargs = {"config": self.config.get_text_config(), "max_cache_len": max_cache_len}
-            self._cache = cache_cls(**cache_kwargs)
+            cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
+            self._cache = StaticCache(**cache_kwargs)
             if requires_cross_attention_cache:
                 encoder_kwargs = cache_kwargs.copy()
                 encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
-                self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
+                self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
         else:
             self._cache.reset()
         return self._cache

@@ -1957,7 +1948,12 @@ def _prepare_cache_for_generation(
             else {}
         )
         if generation_config.cache_implementation is not None:
-            if generation_config.cache_implementation in STATIC_CACHE_CLASSES_MAPPING:
+            if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS:
+                if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS:
+                    logger.warning_once(
+                        f"Using `cache_implementation='{generation_config.cache_implementation}' is deprecated. Please only "
+                        f"use one of {STATIC_CACHE_IMPLEMENTATIONS}, and the layer structure will be inferred automatically."
+                    )
                 model_kwargs[cache_name] = self._get_cache(
                     cache_implementation=generation_config.cache_implementation,
                     batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,

@@ -1977,7 +1973,6 @@
                 cache_config["config"] = self.config.get_text_config()
                 # Pop the backend from the config (defaults to quanto if not defined)
                 backend = cache_config.pop("backend", "quanto")
-                cache_class = QUANT_BACKEND_CLASSES_MAPPING[backend]
 
                 if backend == "quanto" and not is_optimum_quanto_available():
                     raise ImportError(

@@ -1989,10 +1984,9 @@
                         "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                         "Please install it via with `pip install hqq`"
                     )
-
-                model_kwargs[cache_name] = cache_class(**cache_config)
+                model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config)
             elif generation_config.cache_implementation == "offloaded":
-                model_kwargs[cache_name] = OffloadedCache()
+                model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs, offloading=True)
             elif generation_config.cache_implementation == "dynamic":
                 model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)

src/transformers/integrations/executorch.py

Lines changed: 6 additions & 10 deletions
@@ -20,7 +20,6 @@
     DynamicLayer,
     DynamicSlidingWindowLayer,
     EncoderDecoderCache,
-    HybridCache,
     StaticCache,
 )
 from ..generation.configuration_utils import GenerationConfig

@@ -38,9 +37,6 @@
 )
 
 
-# Add this to src/transformers/integrations/executorch.py
-
-
 class TorchExportableModuleForVLM:
     """
     A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.

@@ -207,7 +203,7 @@ def __init__(
         model: PreTrainedModel,
     ):
         """
-        Initializes the exportable module with `HybridCache`.
+        Initializes the exportable module.
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.

@@ -636,7 +632,7 @@ def generate(
 class TorchExportableModuleWithHybridCache(torch.nn.Module):
     """
     A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
-    specifically for decoder-only LM to `HybridCache`. This module ensures that the
+    specifically for decoder-only LM to hybrid `StaticCache`. This module ensures that the
     exported model is compatible with further lowering and execution in `ExecuTorch`.
     """

@@ -645,13 +641,13 @@ def __init__(
         model: PreTrainedModel,
     ):
         """
-        Initializes the exportable module with `HybridCache`.
+        Initializes the exportable module.
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
 
         Raises:
-            AssertionError: If the model doesn't have the expected configuration for HybridCache.
+            AssertionError: If the model doesn't have the expected configuration for an hybrid StaticCache.
         """
         super().__init__()
         self.model = model

@@ -676,8 +672,8 @@ def __init__(
         if not config.use_cache:
             raise AssertionError("Model must have caching enabled.")
 
-        # Initialize the HybridCache
-        self.cache = HybridCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
+        # Initialize the cache
+        self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
         max_batch_size = generation_config.cache_config.get("batch_size")
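A sketch of the equivalent standalone cache setup, assuming a hybrid-attention config such as Gemma-2 (built here from `Gemma2Config` defaults to stay self-contained) and the `cache_config` keys used by the module above; values are illustrative only.

```python
from transformers import Gemma2Config, GenerationConfig
from transformers.cache_utils import StaticCache

config = Gemma2Config()  # default config of a hybrid (sliding + full attention) model
generation_config = GenerationConfig(cache_config={"batch_size": 1, "max_cache_len": 512})

# One StaticCache; the hybrid/sliding layer layout is read from the config,
# so no dedicated HybridCache class is needed anymore.
cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
max_batch_size = generation_config.cache_config.get("batch_size")
```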

src/transformers/masking_utils.py

Lines changed: 6 additions & 6 deletions
@@ -736,7 +736,7 @@
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
-    has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
+    has an hybrid cache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
     to what is needed in the `modeling_xxx.py` files).
 
     Args:

@@ -761,7 +761,7 @@
             An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
     """
-    # If we have an HybridCache structure, here we want to create the mask for the full layers
+    # If we have an hybrid cache structure, here we want to create the mask for the full layers
     if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
         layer_idx = past_key_values.is_sliding.index(False)
     else:

@@ -828,7 +828,7 @@
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
-    of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this
+    of attention pattern was mostly democratized by Mistral. If `past_key_values` has an hybrid cache structure, this
     function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
     `modeling_xxx.py` files).

@@ -854,7 +854,7 @@
             An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
     """
-    # If we have an HybridCache structure, here we want to create the mask for the sliding layers
+    # If we have an hybrid cache structure, here we want to create the mask for the sliding layers
     if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
         layer_idx = past_key_values.is_sliding.index(True)
     else:

@@ -923,7 +923,7 @@
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
-    of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this
+    of attention pattern was mostly democratized by Llama4. If `past_key_values` has an hybrid cache structure, this
     function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
     `modeling_xxx.py` files).

@@ -949,7 +949,7 @@
             An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
     """
-    # If we have an HybridCache structure, here we want to create the mask for the sliding layers
+    # If we have an hybrid cache structure, here we want to create the mask for the sliding layers
     if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
         layer_idx = past_key_values.is_sliding.index(True)
     else:
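The layer-selection rule these mask builders share can be illustrated with a toy object; `FakeHybridCache` below is purely illustrative and only mimics the per-layer `is_sliding` attribute that the helpers inspect.

```python
class FakeHybridCache:
    """Illustrative stand-in exposing only the per-layer `is_sliding` flags."""

    def __init__(self, is_sliding):
        self.is_sliding = is_sliding


# Layers 0, 1, 3 use sliding attention; layer 2 is full attention.
past_key_values = FakeHybridCache(is_sliding=[True, True, False, True])

# create_causal_mask picks a "full_attention" layer when one exists...
full_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0

# ...while the sliding/chunked variants pick a sliding layer.
sliding_idx = past_key_values.is_sliding.index(True) if True in past_key_values.is_sliding else 0

print(full_idx, sliding_idx)  # 2 0
```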

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 1 addition & 1 deletion
@@ -467,7 +467,7 @@ class BambaMixer(nn.Module):
     and is why Mamba is called **selective** state spaces)
 
     The are a few differences between this and Mamba2Mixer:
-    - The variable use_precomputed_states is slightly different due to the HybridCache structure
+    - The variable use_precomputed_states is slightly different due to the hybrid cache structure
     - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
     - Some extra variables that our layer doesn't need have been removed
     - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

src/transformers/models/bamba/modular_bamba.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ class BambaMixer(nn.Module):
     and is why Mamba is called **selective** state spaces)
 
     The are a few differences between this and Mamba2Mixer:
-    - The variable use_precomputed_states is slightly different due to the HybridCache structure
+    - The variable use_precomputed_states is slightly different due to the hybrid cache structure
     - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
     - Some extra variables that our layer doesn't need have been removed
     - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

src/transformers/models/gpt_neo/modeling_gpt_neo.py

Lines changed: 1 addition & 1 deletion
@@ -477,7 +477,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["GPTNeoBlock"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
-    _can_compile_fullgraph = False  # TODO: needs a HybridCache
+    _can_compile_fullgraph = False  # TODO: needs a hybrid cache
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
     and is why Mamba is called **selective** state spaces)
 
     The are a few differences between this and Mamba2Mixer:
-    - The variable use_precomputed_states is slightly different due to the HybridCache structure
+    - The variable use_precomputed_states is slightly different due to the hybrid cache structure
     - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
     - Some extra variables that our layer doesn't need have been removed
     - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py

Lines changed: 6 additions & 10 deletions
@@ -27,7 +27,7 @@
 import torch.nn as nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationConfig, GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available

@@ -945,14 +945,9 @@ def _update_causal_mask(
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
 
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.config._attn_implementation == "sdpa"
-            and not (using_static_cache or using_sliding_window_cache)
-            and not output_attentions
-        ):
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,

@@ -965,8 +960,8 @@
         dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        # SlidingWindowCache or StaticCache
-        if using_sliding_window_cache or using_static_cache:
+        # StaticCache
+        if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         # DynamicCache or no cache
         else:

@@ -1049,7 +1044,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
             # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
             # the check is needed to verify is current checkpoint was trained with sliding window or not
-            if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+            is_static_sliding_cache = isinstance(past_key_values, StaticCache) and all(past_key_values.is_sliding)
+            if not is_static_sliding_cache or sequence_length > target_length:
                 sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                     cache_position.reshape(-1, 1) - text_config.sliding_window
                 )
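The sliding-window exclusion computed in the last hunk can be checked with toy numbers; the window size and positions below are arbitrary.

```python
import torch

sliding_window = 3
target_length = 8
cache_position = torch.arange(5, 8)  # current query positions 5, 6, 7

# Same arithmetic as above: True marks key positions that fall outside the
# sliding window for each query position and therefore get masked out.
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
print(sliding_attend_mask.int())
# tensor([[1, 1, 1, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 0]])
```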
