@@ -2157,3 +2157,64 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
 
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks if layerwise casting works correctly with PEFT layers and that the forward pass does not
+        fail. This is different from `test_layerwise_casting_inference_denoiser`, as that test disables the
+        application of layerwise cast hooks on the PEFT layers (relevant logic in
+        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
+        the PEFT layers as well. If run with PEFT version < 0.14.0, this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
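+        # float8 storage with float32 compute: without the PEFT input-autocast handling, this dtype mix is what
+        # triggers the RuntimeError quoted in the docstring above.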
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
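+        # Attach a LoRA adapter to the denoiser (transformer or UNet, depending on the pipeline under test).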
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
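+        # Collect the default skip patterns plus any model-specific ones; modules matching these patterns are
+        # expected to stay in compute_dtype.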
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
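+        # Apply layerwise casting directly so the PEFT layers are included as well
+        # (ModelMixin.enable_layerwise_casting would skip them).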
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+
+        # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+        for name, module in denoiser.named_modules():
+            if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                continue
+            dtype_to_check = storage_dtype
+            if any(re.search(pattern, name) for pattern in patterns_to_check):
+                dtype_to_check = compute_dtype
+            if getattr(module, "weight", None) is not None:
+                self.assertEqual(module.weight.dtype, dtype_to_check)
+            if getattr(module, "bias", None) is not None:
+                self.assertEqual(module.bias.dtype, dtype_to_check)
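+            # PEFT (LoRA) layers should additionally carry the input-autocast-disable hook.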
+            if isinstance(module, BaseTunerLayer):
+                self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
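+        # The forward pass should now succeed; see hooks.layerwise_casting.PeftInputAutocastDisableHook for how the
+        # dtype mismatch on the PEFT layers is avoided.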
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]