File tree Expand file tree Collapse file tree 1 file changed +24
-1
lines changed 
tests/models/transformers Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change 1818import  torch 
1919
2020from  diffusers  import  LTXVideoTransformer3DModel 
21- from  diffusers .utils .testing_utils  import  enable_full_determinism , torch_device 
21+ from  diffusers .utils .testing_utils  import  (
22+     enable_full_determinism ,
23+     is_torch_compile ,
24+     require_torch_2 ,
25+     require_torch_gpu ,
26+     slow ,
27+     torch_device ,
28+ )
2229
2330from  ..test_modeling_common  import  ModelTesterMixin 
2431
@@ -81,3 +88,19 @@ def prepare_init_args_and_inputs_for_common(self):
8188    def  test_gradient_checkpointing_is_applied (self ):
8289        expected_set  =  {"LTXVideoTransformer3DModel" }
8390        super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
91+ 
92+     @require_torch_gpu  
93+     @require_torch_2  
94+     @is_torch_compile  
95+     @slow  
96+     def  test_torch_compile_recompilation_and_graph_break (self ):
97+         torch ._dynamo .reset ()
98+         init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
99+ 
100+         model  =  self .model_class (** init_dict ).to (torch_device )
101+         model .eval ()
102+         model  =  torch .compile (model , fullgraph = True )
103+ 
104+         with  torch ._dynamo .config .patch (error_on_recompile = True ), torch .no_grad ():
105+             _  =  model (** inputs_dict )
106+             _  =  model (** inputs_dict )
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments