@@ -2098,3 +2098,43 @@ def test_correct_lora_configs_with_different_ranks(self):
20982098        lora_output_diff_alpha  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
20992099        self .assertTrue (not  np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
21002100        self .assertTrue (not  np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
2101+ 
2102+     def  test_layerwise_upcasting_inference_denoiser (self ):
2103+         def  initialize_pipeline (storage_dtype = None , compute_dtype = torch .float32 ):
2104+             components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (self .scheduler_classes [0 ])
2105+             pipe  =  self .pipeline_class (** components )
2106+             pipe  =  pipe .to (torch_device , dtype = compute_dtype )
2107+             pipe .set_progress_bar_config (disable = None )
2108+ 
2109+             if  "text_encoder"  in  self .pipeline_class ._lora_loadable_modules :
2110+                 pipe .text_encoder .add_adapter (text_lora_config )
2111+                 self .assertTrue (
2112+                     check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" 
2113+                 )
2114+ 
2115+             denoiser  =  pipe .transformer  if  self .unet_kwargs  is  None  else  pipe .unet 
2116+             denoiser .add_adapter (denoiser_lora_config )
2117+             self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2118+ 
2119+             if  self .has_two_text_encoders  or  self .has_three_text_encoders :
2120+                 if  "text_encoder_2"  in  self .pipeline_class ._lora_loadable_modules :
2121+                     pipe .text_encoder_2 .add_adapter (text_lora_config )
2122+                     self .assertTrue (
2123+                         check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2" 
2124+                     )
2125+ 
2126+             if  storage_dtype  is  not None :
2127+                 denoiser .enable_layerwise_upcasting (storage_dtype = storage_dtype , compute_dtype = compute_dtype )
2128+ 
2129+             return  pipe 
2130+ 
2131+         _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
2132+ 
2133+         pipe_fp32  =  initialize_pipeline (storage_dtype = None )
2134+         pipe_fp32 (** inputs , generator = torch .manual_seed (0 ))[0 ]
2135+ 
2136+         pipe_float8_e4m3_fp32  =  initialize_pipeline (storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .float32 )
2137+         pipe_float8_e4m3_fp32 (** inputs , generator = torch .manual_seed (0 ))[0 ]
2138+ 
2139+         pipe_float8_e4m3_bf16  =  initialize_pipeline (storage_dtype = torch .float8_e4m3fn , compute_dtype = torch .bfloat16 )
2140+         pipe_float8_e4m3_bf16 (** inputs , generator = torch .manual_seed (0 ))[0 ]
0 commit comments