File tree Expand file tree Collapse file tree 3 files changed +31
-19
lines changed Expand file tree Collapse file tree 3 files changed +31
-19
lines changed Original file line number Diff line number Diff line change @@ -1714,6 +1714,35 @@ def test_push_to_hub_library_name(self):
17141714        delete_repo (self .repo_id , token = TOKEN )
17151715
17161716
1717+ class  TorchCompileTesterMixin :
1718+     def  setUp (self ):
1719+         # clean up the VRAM before each test 
1720+         super ().setUp ()
1721+         torch ._dynamo .reset ()
1722+         gc .collect ()
1723+         backend_empty_cache (torch_device )
1724+ 
1725+     def  tearDown (self ):
1726+         # clean up the VRAM after each test in case of CUDA runtime errors 
1727+         super ().tearDown ()
1728+         torch ._dynamo .reset ()
1729+         gc .collect ()
1730+         backend_empty_cache (torch_device )
1731+ 
1732+     @require_torch_gpu  
1733+     @require_torch_2  
1734+     @slow  
1735+     def  test_torch_compile_recompilation_and_graph_break (self ):
1736+         torch ._dynamo .reset ()
1737+         init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1738+ 
1739+         model  =  self .model_class (** init_dict ).to (torch_device )
1740+         model  =  torch .compile (model , fullgraph = True )
1741+ 
1742+         with  torch ._dynamo .config .patch (error_on_recompile = True ), torch .no_grad ():
1743+             _  =  model (** inputs_dict )
1744+ 
1745+ 
17171746@slow  
17181747@require_torch_2  
17191748@require_torch_accelerator  
Original file line number Diff line number Diff line change 2222from  diffusers .models .embeddings  import  ImageProjection 
2323from  diffusers .utils .testing_utils  import  enable_full_determinism , torch_device 
2424
25- from  ..test_modeling_common  import  ModelTesterMixin 
25+ from  ..test_modeling_common  import  ModelTesterMixin ,  TorchCompileTesterMixin 
2626
2727
2828enable_full_determinism ()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
7878    return  ip_state_dict 
7979
8080
81- class  FluxTransformerTests (ModelTesterMixin , unittest .TestCase ):
81+ class  FluxTransformerTests (ModelTesterMixin , TorchCompileTesterMixin ,  unittest .TestCase ):
8282    model_class  =  FluxTransformer2DModel 
8383    main_input_name  =  "hidden_states" 
8484    # We override the items here because the transformer under consideration is small. 
Original file line number Diff line number Diff line change 5656    require_torch_gpu ,
5757    require_transformers_version_greater ,
5858    skip_mps ,
59-     slow ,
6059    torch_device ,
6160)
6261
@@ -2165,22 +2164,6 @@ def test_StableDiffusionMixin_component(self):
21652164            )
21662165        )
21672166
2168-     @require_torch_gpu  
2169-     @slow  
2170-     def  test_torch_compile_recompilation_and_graph_break (self ):
2171-         torch ._dynamo .reset ()
2172-         inputs  =  self .get_dummy_inputs (torch_device )
2173-         components  =  self .get_dummy_components ()
2174- 
2175-         pipe  =  self .pipeline_class (** components ).to (torch_device )
2176-         if  getattr (pipe , "unet" , None ) is  not   None :
2177-             pipe .unet  =  torch .compile (pipe .unet , fullgraph = True )
2178-         else :
2179-             pipe .transformer  =  torch .compile (pipe .transformer , fullgraph = True )
2180- 
2181-         with  torch ._dynamo .config .patch (error_on_recompile = True ):
2182-             _  =  pipe (** inputs )
2183- 
21842167    @require_hf_hub_version_greater ("0.26.5" ) 
21852168    @require_transformers_version_greater ("4.47.1" ) 
21862169    def  test_save_load_dduf (self , atol = 1e-4 , rtol = 1e-4 ):
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments