File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed 
tests/quantization/quanto Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change 101101            mps_backend_registered  =  hasattr (torch .backends , "mps" )
102102            torch_device  =  "mps"  if  (mps_backend_registered  and  torch .backends .mps .is_available ()) else  torch_device 
103103
104+     from  .torch_utils  import  get_torch_cuda_device_capability 
105+ 
104106
105107def  torch_all_close (a , b , * args , ** kwargs ):
106108    if  not  is_torch_available ():
@@ -282,6 +284,18 @@ def require_torch_gpu(test_case):
282284    )
283285
284286
def require_torch_cuda_compatibility(expected_compute_capability):
    """Decorator skipping a test unless the current CUDA device's compute
    capability equals ``expected_compute_capability`` (e.g. ``8.0`` for Ampere).

    Skips outright when no CUDA device is available.
    """

    def decorator(test_case):
        if not torch.cuda.is_available():
            # `unittest.skip` takes a *reason* string and returns a decorator;
            # the decorator must then be applied to the test case.
            return unittest.skip("test requires a CUDA device")(test_case)
        # presumably returns a (major, minor) tuple like
        # torch.cuda.get_device_capability() — TODO confirm against torch_utils.
        major, minor = get_torch_cuda_device_capability()
        # Join with a dot so (8, 0) compares as 8.0, not 80.0.
        current_compute_capability = float(f"{major}.{minor}")
        return unittest.skipUnless(
            current_compute_capability == float(expected_compute_capability),
            f"test requires compute capability {expected_compute_capability}, found {current_compute_capability}",
        )(test_case)

    return decorator
297+ 
298+ 
285299# These decorators are for accelerator-specific behaviours that are not GPU-specific 
286300def  require_torch_accelerator (test_case ):
287301    """Decorator marking a test that requires an accelerator backend and PyTorch.""" 
Original file line number Diff line number Diff line change 1010    numpy_cosine_similarity_distance ,
1111    require_accelerate ,
1212    require_big_gpu_with_torch_cuda ,
13+     require_torch_cuda_compatibility ,
1314    torch_device ,
1415)
1516
@@ -311,13 +312,19 @@ def get_dummy_init_kwargs(self):
311312        return  {"weights_dtype" : "int8" }
312313
313314
# NOTE(review): the original called `require_torch_cuda_compatibility(8.0)` as a
# bare statement, which builds a decorator and discards it — a no-op. It must be
# applied with `@` so the capability gate actually attaches to the class.
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    # Expected fraction of memory saved by int4 quantization — TODO confirm
    # how the mixin interprets this threshold.
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        """Quanto init kwargs selecting 4-bit integer weight quantization."""
        return {"weights_dtype": "int4"}
320324
# NOTE(review): as with the int4 test, the bare `require_torch_cuda_compatibility(8.0)`
# call was a discarded no-op; the decorator must be `@`-applied to the class.
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    # Expected fraction of memory saved by int2 quantization — TODO confirm
    # how the mixin interprets this threshold.
    expected_memory_reduction = 0.65
323330
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments