Skip to content

Commit e6f9b03

Browse files
authored
[Compile] Only test compiling model forward pass (#35658)
* rename test to only compile forward! * style emu
1 parent 84a6789 commit e6f9b03

File tree

7 files changed

+9
-34
lines changed

7 files changed

+9
-34
lines changed

tests/generation/test_utils.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,16 +2042,10 @@ def test_generate_with_quant_cache(self):
20422042
with self.assertRaises(ValueError):
20432043
model.generate(**generation_kwargs, **inputs_dict)
20442044

2045-
@parameterized.expand(
2046-
[
2047-
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
2048-
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
2049-
]
2050-
)
20512045
@pytest.mark.generate
20522046
@require_torch_gpu
20532047
@slow
2054-
def test_generate_compile(self, _, end_to_end):
2048+
def test_generate_compile_model_forward(self):
20552049
"""
20562050
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
20572051
end-to-end compilation and forward pass compilation only.
@@ -2061,14 +2055,7 @@ def test_generate_compile(self, _, end_to_end):
20612055
if not model_class._supports_static_cache:
20622056
self.skipTest("This model doesn't support static cache")
20632057

2064-
# TODO (joao) -- fix and enable me :)
2065-
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
2066-
self.skipTest("whisper model end-to-end generate compile not yet supported")
2067-
20682058
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
2069-
# TODO (joao) -- fix and enable me :)
2070-
if end_to_end and config.is_encoder_decoder:
2071-
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
20722059

20732060
model = model_class(config).to(torch_device)
20742061
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
@@ -2084,10 +2071,8 @@ def test_generate_compile(self, _, end_to_end):
20842071
"max_new_tokens": 10,
20852072
"return_dict_in_generate": True,
20862073
"output_scores": True,
2074+
"cache_implementation": "static",
20872075
}
2088-
# end-to-end works best with dynamic cache, forward compilation works best with static cache
2089-
if not end_to_end:
2090-
generation_kwargs["cache_implementation"] = "static"
20912076

20922077
# get eager + dynamic cache results for future comparison
20932078
dynamic_outputs = []
@@ -2098,10 +2083,8 @@ def test_generate_compile(self, _, end_to_end):
20982083
generation_config = copy.deepcopy(model.generation_config)
20992084
generation_config.update(**generation_kwargs)
21002085
torch.compiler.reset()
2101-
if end_to_end:
2102-
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
2103-
else:
2104-
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
2086+
2087+
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
21052088

21062089
compiled_outputs = []
21072090
for model_inputs in input_ids_sets:

tests/models/chameleon/test_modeling_chameleon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def test_batching_equivalence(self):
333333

334334
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
335335
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
336-
def test_generate_compile_fullgraph(self):
336+
def test_generate_compile_model_forward(self):
337337
pass
338338

339339

tests/models/dbrx/test_modeling_dbrx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def test_disk_offload_bin(self):
369369
pass
370370

371371
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
372-
def test_generate_compile_fullgraph(self):
372+
def test_generate_compile_model_forward(self):
373373
pass
374374

375375

tests/models/emu3/test_modeling_emu3.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,6 @@ def test_model_rope_scaling(self, scaling_type):
176176
def test_custom_4d_attention_mask(self):
177177
pass
178178

179-
@unittest.skip("Fails with unknown error only on end-to-end compile") # TODO raushan fixme
180-
def test_generate_compile_1_end_to_end(self):
181-
pass
182-
183179

184180
class Emu3Vision2TextModelTester:
185181
def __init__(
@@ -398,10 +394,6 @@ def test_custom_4d_attention_mask(self):
398394
def test_initialization(self):
399395
pass
400396

401-
@unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`")
402-
def test_generate_compile_1_end_to_end(self):
403-
pass
404-
405397

406398
@require_torch
407399
class Emu3IntegrationTest(unittest.TestCase):

tests/models/idefics/test_modeling_idefics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def test_custom_4d_attention_mask(self):
781781
pass
782782

783783
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
784-
def test_generate_compile_fullgraph(self):
784+
def test_generate_compile_model_forward(self):
785785
pass
786786

787787
@unittest.skip(reason="We only test the model that takes in multiple images")

tests/models/paligemma/test_modeling_paligemma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
348348

349349
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
350350
@unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
351-
def test_generate_compile_fullgraph(self):
351+
def test_generate_compile_model_forward(self):
352352
pass
353353

354354

tests/models/qwen2_vl/test_modeling_qwen2_vl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
333333
pass
334334

335335
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
336-
def test_generate_compile_fullgraph(self):
336+
def test_generate_compile_model_forward(self):
337337
pass
338338

339339

0 commit comments

Comments
 (0)