@@ -154,21 +154,31 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
 
         return inputs
 
-    def get_dummy_tensor_inputs(self, device=None):
+    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
         height = width = 4
         sequence_length = 48
         embedding_dim = 32
 
+        torch.manual_seed(seed)
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
             device, dtype=torch.bfloat16
         )
+
+        torch.manual_seed(seed)
         pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+        torch.manual_seed(seed)
         image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
         timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
 
         return {
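
The reseeding in this hunk makes each dummy tensor a function of `seed` alone rather than of draw order, so adding or removing one `torch.randn` call never perturbs the tensors that follow it. A minimal sketch of the pattern, assuming only PyTorch (the `make_dummy_tensors` helper is hypothetical):

import torch

def make_dummy_tensors(seed: int = 0):
    # Reseeding immediately before each draw pins every tensor to the
    # seed alone, so reordering or removing one draw never shifts the others.
    torch.manual_seed(seed)
    hidden = torch.randn(1, 16, 4)
    torch.manual_seed(seed)
    pooled = torch.randn(1, 32)
    return hidden, pooled

h1, p1 = make_dummy_tensors()
h2, p2 = make_dummy_tensors()
assert torch.equal(h1, h2) and torch.equal(p1, p2)

One side effect worth knowing: draws taken from the same seed share a common prefix of random values, which is harmless for dummy test inputs but would matter if the tensors had to be independent.
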
@@ -322,6 +332,22 @@ def test_training(self):
                 self.assertTrue(module.adapter[1].weight.grad is not None)
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
 
+    def test_torch_compile(self):
+        quantization_config = TorchAoConfig("int8_weight_only")
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components)
+        pipe.to(device=torch_device, dtype=torch.bfloat16)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        normal_output = pipe(**inputs)[0].flatten()[-32:]
+
+        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
+        inputs = self.get_dummy_inputs(torch_device)
+        compile_output = pipe(**inputs)[0].flatten()[-32:]
+
+        # Note: Seems to require higher tolerance
+        self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
+
 
 @require_torch
 @require_torch_gpu
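
The new test captures an eager baseline, compiles only the transformer, reruns with freshly built (but identically seeded) inputs, and compares a slice of the outputs under a loosened tolerance. A self-contained sketch of that compare-before-and-after-compile pattern, assuming only PyTorch (`TinyModel` is a hypothetical stand-in for the quantized transformer):

import torch

class TinyModel(torch.nn.Module):  # hypothetical stand-in for the transformer
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = TinyModel().eval()
x = torch.randn(2, 8)

with torch.no_grad():
    # Capture the eager baseline before compiling, mirroring the test above.
    eager_out = model(x)
    compiled = torch.compile(model, fullgraph=True, dynamic=False)
    compiled_out = compiled(x)

# Compiled kernels may fuse and reorder floating-point ops, so compare
# with a tolerance rather than expecting bitwise equality.
assert torch.allclose(eager_out, compiled_out, atol=1e-2, rtol=1e-3)

The loosened `atol`/`rtol` in the test reflects the same point: weight-only int8 quantization plus autotuned kernels can legitimately shift low-order bits of the output.
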