@@ -2042,16 +2042,10 @@ def test_generate_with_quant_cache(self):
         with self.assertRaises(ValueError):
             model.generate(**generation_kwargs, **inputs_dict)
 
-    @parameterized.expand(
-        [
-            ("forward_only", False),  # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
-            ("end_to_end", True),  # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
-        ]
-    )
     @pytest.mark.generate
     @require_torch_gpu
     @slow
-    def test_generate_compile(self, _, end_to_end):
+    def test_generate_compile_model_forward(self):
         """
         Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
         end-to-end compilation and forward pass compilation only.
@@ -2061,14 +2055,7 @@ def test_generate_compile(self, _, end_to_end):
         if not model_class._supports_static_cache:
             self.skipTest("This model doesn't support static cache")
 
-        # TODO (joao) -- fix and enable me :)
-        if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
-            self.skipTest("whisper model end-to-end generate compile not yet supported")
-
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        # TODO (joao) -- fix and enable me :)
-        if end_to_end and config.is_encoder_decoder:
-            self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
 
         model = model_class(config).to(torch_device)
         model.eval()  # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
@@ -2084,10 +2071,8 @@ def test_generate_compile(self, _, end_to_end):
20842071 "max_new_tokens" : 10 ,
20852072 "return_dict_in_generate" : True ,
20862073 "output_scores" : True ,
2074+ "cache_implementation" : "static" ,
20872075 }
2088- # end-to-end works best with dynamic cache, forward compilation works best with static cache
2089- if not end_to_end :
2090- generation_kwargs ["cache_implementation" ] = "static"
20912076
20922077 # get eager + dynamic cache results for future comparison
20932078 dynamic_outputs = []
@@ -2098,10 +2083,8 @@ def test_generate_compile(self, _, end_to_end):
         generation_config = copy.deepcopy(model.generation_config)
         generation_config.update(**generation_kwargs)
         torch.compiler.reset()
-        if end_to_end:
-            model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
-        else:
-            model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
+
+        model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
 
         compiled_outputs = []
         for model_inputs in input_ids_sets:
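
For reference, a minimal sketch of the pattern the updated test exercises: compile only the model's forward pass and pair it with a static KV cache so tensor shapes stay fixed across decoding steps. The checkpoint name and generation settings here are illustrative assumptions, not taken from the diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical tiny checkpoint, used only for illustration.
checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

# Compile the forward pass only. `fullgraph=True` fails fast on graph breaks;
# "reduce-overhead" (CUDA graphs) targets the repeated calls of the decoding
# loop and is most effective on GPU.
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

inputs = tokenizer("Hello", return_tensors="pt")
# The static cache preallocates key/value tensors at a fixed size, so the
# compiled forward sees stable shapes instead of recompiling on every token.
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(outputs[0]))

The dropped parametrization compiled `model.generate` end-to-end instead; per the removed TODO, that path is broken with torch 2.5+, which is why only forward-pass compilation is kept.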