1414# limitations under the License. 
1515import  inspect 
1616import  os 
17+ import  re 
1718import  tempfile 
1819import  unittest 
1920from  itertools  import  product 
@@ -2100,6 +2101,23 @@ def test_correct_lora_configs_with_different_ranks(self):
21002101        self .assertTrue (not  np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
21012102
21022103    def  test_layerwise_upcasting_inference_denoiser (self ):
2104+         from  diffusers .hooks .layerwise_upcasting  import  DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS 
2105+ 
2106+         def  check_linear_dtype (module , storage_dtype , compute_dtype ):
2107+             patterns_to_check  =  DEFAULT_SKIP_MODULES_PATTERN 
2108+             if  getattr (module , "_precision_sensitive_module_patterns" , None ) is  not None :
2109+                 patterns_to_check  +=  tuple (module ._precision_sensitive_module_patterns )
2110+             for  name , submodule  in  module .named_modules ():
2111+                 if  not  isinstance (submodule , SUPPORTED_PYTORCH_LAYERS ):
2112+                     continue 
2113+                 dtype_to_check  =  storage_dtype 
2114+                 if  "lora"  in  name  or  any (re .search (pattern , name ) for  pattern  in  patterns_to_check ):
2115+                     dtype_to_check  =  compute_dtype 
2116+                 if  getattr (submodule , "weight" , None ) is  not None :
2117+                     self .assertEqual (submodule .weight .dtype , dtype_to_check )
2118+                 if  getattr (submodule , "bias" , None ) is  not None :
2119+                     self .assertEqual (submodule .bias .dtype , dtype_to_check )
2120+ 
21032121        def  initialize_pipeline (storage_dtype = None , compute_dtype = torch .float32 ):
21042122            components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (self .scheduler_classes [0 ])
21052123            pipe  =  self .pipeline_class (** components )
@@ -2125,6 +2143,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
21252143
21262144            if  storage_dtype  is  not None :
21272145                denoiser .enable_layerwise_upcasting (storage_dtype = storage_dtype , compute_dtype = compute_dtype )
2146+                 check_linear_dtype (denoiser , storage_dtype , compute_dtype )
21282147
21292148            return  pipe 
21302149
0 commit comments