
Commit a2e76b9

🚨🚨 Switch default compilation to fullgraph=False (#40137)
* switch default
* docstring
* docstring
* rework tests and remove outdated restrictions
* simplify
* we need a check for static cache
* fix
* rename var
* fix
* revert
* style
* rename test
1 parent 2b59207 commit a2e76b9

File tree

16 files changed: +38 / -101 lines changed


src/transformers/generation/configuration_utils.py

Lines changed: 8 additions & 6 deletions
````diff
@@ -43,7 +43,7 @@
 
 logger = logging.get_logger(__name__)
 
 METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
-NEED_SETUP_CACHE_CLASSES_MAPPING = {}
+STATIC_CACHE_CLASSES_MAPPING = {}
 QUANT_BACKEND_CLASSES_MAPPING = {}
 ALL_CACHE_IMPLEMENTATIONS = []
@@ -60,7 +60,7 @@
     )
     from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
 
-    NEED_SETUP_CACHE_CLASSES_MAPPING = {
+    STATIC_CACHE_CLASSES_MAPPING = {
         "static": StaticCache,
         "offloaded_static": OffloadedStaticCache,
         "sliding_window": SlidingWindowCache,
@@ -70,7 +70,7 @@
         "offloaded_hybrid_chunked": OffloadedHybridCache,
     }
     QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
-    ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]
+    ALL_CACHE_IMPLEMENTATIONS = list(STATIC_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]
 
 
 class GenerationMode(ExplicitEnum):
@@ -1536,8 +1536,10 @@ class CompileConfig:
     See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
 
     Args:
-        fullgraph (`bool`, *optional*, defaults to `True`):
-            If `True`, requires that the whole forward be capturable in a single graph.
+        fullgraph (`bool`, *optional*, defaults to `False`):
+            If `False` (default), attempts to discover compileable regions, which are then optimized. If `True`,
+            requires that the entire function be capturable into a single graph; if that is not possible (that is,
+            if there are graph breaks), an error is raised.
         dynamic (`bool` or `None`, *optional*):
             Whether to try to use dynamic shape graphs.
         backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
@@ -1566,7 +1568,7 @@ class CompileConfig:
     ```
     """
 
-    fullgraph: bool = True
+    fullgraph: bool = False
     dynamic: Optional[bool] = None
     backend: Union[str, Callable] = "inductor"
     mode: str = "reduce-overhead"
````

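Concretely, `CompileConfig()` now tolerates graph breaks by default. A minimal sketch of opting back into the old strict behavior, following the usage pattern from the `CompileConfig` docstring (the checkpoint name is only illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

# Illustrative checkpoint; any generative model works the same way.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

assert CompileConfig().fullgraph is False  # the new default after this commit

# Opt back into the old strict single-graph behavior explicitly:
model.generation_config.cache_implementation = "static"
model.generation_config.compile_config = CompileConfig(fullgraph=True)

inputs = tokenizer("Compile it!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Without an explicit `compile_config`, generation still auto-compiles when a compileable cache is in use, just without the single-graph requirement.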
src/transformers/generation/utils.py

Lines changed: 7 additions & 18 deletions
```diff
@@ -71,9 +71,8 @@
     _prepare_token_type_ids,
 )
 from .configuration_utils import (
-    NEED_SETUP_CACHE_CLASSES_MAPPING,
     QUANT_BACKEND_CLASSES_MAPPING,
-    CompileConfig,
+    STATIC_CACHE_CLASSES_MAPPING,
     GenerationConfig,
     GenerationMode,
 )
@@ -1826,7 +1825,7 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
         if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
             cache_implementation = "hybrid_chunked"
 
-        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        cache_cls: Cache = STATIC_CACHE_CLASSES_MAPPING[cache_implementation]
         requires_cross_attention_cache = (
             self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
         )
@@ -1958,12 +1957,7 @@ def _prepare_cache_for_generation(
             else {}
         )
         if generation_config.cache_implementation is not None:
-            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
-                if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
-                    raise ValueError(
-                        "This model does not support `cache_implementation='static'`. Please check the following "
-                        "issue: https://github.com/huggingface/transformers/issues/28981"
-                    )
+            if generation_config.cache_implementation in STATIC_CACHE_CLASSES_MAPPING:
                 model_kwargs[cache_name] = self._get_cache(
                     cache_implementation=generation_config.cache_implementation,
                     batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
@@ -2115,8 +2109,7 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge
         using_compilable_cache = (
            isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
         )
-        # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
-        can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph
+        can_compile = valid_hardware and using_compilable_cache
 
         # Exception 1: Some quantization methods do not support compilation
         if getattr(self, "hf_quantizer", None) is not None:
@@ -3475,13 +3468,9 @@ def _sample(
         if compile_forward:
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             # If we use FA2 and a static cache, we cannot compile with fullgraph
-            if self.config._attn_implementation == "flash_attention_2" and getattr(
-                model_kwargs.get("past_key_values"), "is_compileable", False
-            ):
-                if generation_config.compile_config is None:
-                    generation_config.compile_config = CompileConfig(fullgraph=False)
-                # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
-                elif generation_config.compile_config.fullgraph:
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
                     logger.warning_once(
                         "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                         "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
```

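With the guard removed, any name in the renamed mapping is resolved and instantiated directly; the old veto (raising a `ValueError` for `cache_implementation="static"` on models with `_can_compile_fullgraph = False`) is gone. A small sketch of the lookup `_get_cache` performs, assuming a torch-enabled install (the mapping is only populated when torch is available):

```python
from transformers.generation.configuration_utils import (
    ALL_CACHE_IMPLEMENTATIONS,
    STATIC_CACHE_CLASSES_MAPPING,
)

# `cache_implementation` strings resolve straight to cache classes now;
# no fullgraph-capability check is applied along the way.
print(STATIC_CACHE_CLASSES_MAPPING["static"])          # StaticCache
print(STATIC_CACHE_CLASSES_MAPPING["sliding_window"])  # SlidingWindowCache

# "offloaded", "dynamic" and "quantized" are valid implementations that need
# no static-cache class, so they sit outside the mapping:
print(ALL_CACHE_IMPLEMENTATIONS)
```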
src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -460,8 +460,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = True
+    _can_compile_fullgraph = False
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": DeepseekV2DecoderLayer,
```

src/transformers/models/deepseek_v2/modular_deepseek_v2.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -504,6 +504,8 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int):
 
 
 class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
+    _can_compile_fullgraph = False
+
     def _init_weights(self, module):
         LlamaPreTrainedModel._init_weights(module)
         if isinstance(module, DeepseekV2MoEGate):
```

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -502,8 +502,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = True
+    _can_compile_fullgraph = False
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": DeepseekV3DecoderLayer,
```

src/transformers/models/deepseek_v3/modular_deepseek_v3.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -340,6 +340,8 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int):
 
 
 class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
+    _can_compile_fullgraph = False
+
     def _init_weights(self, module):
         LlamaPreTrainedModel._init_weights(module)
         if isinstance(module, DeepseekV3TopkRouter):
```

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -422,8 +422,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = True
+    _can_compile_fullgraph = False
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": Dots1DecoderLayer,
```

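These MoE models now declare `_can_compile_fullgraph = False`: data-dependent expert routing introduces graph breaks, which only matter when a single graph is required. The plain `torch.compile` semantics the flag (and the new default) mirrors, as a standalone sketch:

```python
import torch

def routed(x):
    # Data-dependent control flow, similar in spirit to MoE token routing:
    # dynamo cannot trace this into one graph and inserts a graph break.
    if x.sum() > 0:
        return x * 2
    return x - 1

x = torch.ones(4)

# fullgraph=False (the new default): compiles the capturable regions, falls
# back to eager around the break, and still returns the right answer.
print(torch.compile(routed, fullgraph=False)(x))

# fullgraph=True (the old default): the same function now errors out.
try:
    torch.compile(routed, fullgraph=True)(x)
except Exception as err:  # torch._dynamo.exc.Unsupported
    print(f"fullgraph=True failed: {type(err).__name__}")
```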
tests/generation/test_utils.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -1736,6 +1736,9 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
         to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
         """
         for model_class in self.all_generative_model_classes:
+            # Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
+            # use a static cache because they don't create the causal masks correctly.
+            # TODO: cyril -> relax this by adding a `_support_static_cache` attribute
             if not model_class._can_compile_fullgraph:
                 self.skipTest(reason="This model does not support the static cache format")
 
@@ -1956,6 +1959,9 @@ def test_generate_with_static_cache(self):
         """
         set_model_tester_for_less_flaky_test(self)
         for model_class in self.all_generative_model_classes:
+            # Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
+            # use a static cache because they don't create the causal masks correctly.
+            # TODO: cyril -> relax this by adding a `_support_static_cache` attribute
             if not model_class._can_compile_fullgraph:
                 self.skipTest(reason="This model does not support the static cache format")
 
@@ -2050,7 +2056,7 @@ def test_generate_with_quant_cache(self):
     @pytest.mark.generate
     @pytest.mark.torch_compile_test
     @require_torch_greater_or_equal("2.6")  # Uses torch.compiler.set_stance
-    def test_generate_compile_model_forward(self):
+    def test_generate_compile_model_forward_fullgraph(self):
         """
         Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that
         `.forward` called from `.generate` sees no graph breaks or recompilations when compiled.
@@ -2098,7 +2104,7 @@ def test_generate_compile_model_forward(self):
         # 3. compilation-specific setup and generation parameterization
         torch.compiler.reset()  # prevent cached compilation from being used in the test
         has_defined_cache_implementation = model.generation_config.cache_implementation is not None
-        compile_config = CompileConfig(dynamic=False)  # Error out on dynamic shapes
+        compile_config = CompileConfig(fullgraph=True, dynamic=False)  # Error out on dynamic shapes
         compile_config._compile_all_devices = True  # force compilation (e.g. fast CI, CPU)
 
         generation_kwargs = {
@@ -2174,8 +2180,11 @@ def test_generate_compilation_all_outputs(self):
         In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
         """
         for model_class in self.all_generative_model_classes:
+            # Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
+            # use a static cache because they don't create the causal masks correctly.
+            # TODO: cyril -> relax this by adding a `_support_static_cache` attribute
             if not model_class._can_compile_fullgraph:
-                self.skipTest("This model doesn't support compilation without graph breaks")
+                self.skipTest(reason="This model does not support the static cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
             if self.has_attentions:
```

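The renamed test pins `fullgraph=True` explicitly now that it is no longer the default; per the `require_torch_greater_or_equal("2.6")` decorator, it relies on `torch.compiler.set_stance` to turn recompilations into hard errors. A self-contained sketch of that detection pattern (the toy `forward` is, of course, illustrative):

```python
import torch

@torch.compile(fullgraph=True, dynamic=False)  # mirror the test's CompileConfig
def forward(x):
    return torch.nn.functional.relu(x @ x.T)

forward(torch.randn(4, 4))  # warmup call triggers the actual compilation

# After warmup, make any recompilation an error; this is how graph breaks or
# shape-driven recompiles in `.forward` can be surfaced during `.generate`.
torch.compiler.set_stance("fail_on_recompile")
forward(torch.randn(4, 4))  # fine: same static shape, cached graph is reused
try:
    forward(torch.randn(8, 8))  # new shape with dynamic=False -> recompile -> error
except Exception as err:
    print(f"recompile blocked: {type(err).__name__}")
finally:
    torch.compiler.set_stance("default")
```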
tests/models/deepseek_v2/test_modeling_deepseek_v2.py

Lines changed: 0 additions & 19 deletions
```diff
@@ -174,25 +174,6 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value
             self.assertEqual(layer.keys.shape, expected_key_shape)
             self.assertEqual(layer.values.shape, expected_value_shape)
 
-    @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
-    @pytest.mark.torch_compile_test
-    def test_generate_compilation_all_outputs(self):
-        pass
-
-    @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
-    @pytest.mark.torch_compile_test
-    def test_generate_compile_model_forward(self):
-        pass
-
-    @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
-    def test_generate_from_inputs_embeds_with_static_cache(self):
-        pass
-
-    @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
-    @pytest.mark.torch_compile_test
-    def test_generate_with_static_cache(self):
-        pass
-
     @unittest.skip("Dynamic control flow in MoE")
     @pytest.mark.torch_compile_test
     def test_torch_compile_for_training(self):
```

tests/models/deepseek_v3/test_modeling_deepseek_v3.py

Lines changed: 0 additions & 21 deletions
```diff
@@ -285,18 +285,6 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
     def test_contrastive_generate_low_memory(self):
         pass
 
-    @unittest.skip(
-        "DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
-    )
-    def test_generate_with_static_cache(self):
-        pass
-
-    @unittest.skip(
-        "DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
-    )
-    def test_generate_from_inputs_embeds_with_static_cache(self):
-        pass
-
     @unittest.skip(
         "DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
     )
@@ -307,15 +295,6 @@ def test_generate_continue_from_inputs_embeds(self):
     def test_beam_search_generate_dict_outputs_use_cache(self):
         pass
 
-    @unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
-    def test_generate_compilation_all_outputs(self):
-        pass
-
-    @unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
-    @pytest.mark.torch_compile_test
-    def test_generate_compile_model_forward(self):
-        pass
-
     @unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
     def test_greedy_generate_dict_outputs_use_cache(self):
         pass
```
