1212# See the License for the specific language governing permissions and 
1313# limitations under the License. 
1414
15+ import  os 
1516import  sys 
1617import  tempfile 
1718import  unittest 
1819
1920import  numpy  as  np 
2021import  pytest 
22+ import  safetensors .torch 
2123import  torch 
24+ from  peft .utils  import  get_peft_model_state_dict 
2225from  PIL  import  Image 
2326from  transformers  import  AutoTokenizer , T5EncoderModel 
2427
@@ -163,6 +166,7 @@ def test_layerwise_casting_inference_denoiser(self):
163166    @require_peft_version_greater ("0.13.2" ) 
164167    def  test_lora_exclude_modules_wanvace (self ):
165168        scheduler_cls  =  self .scheduler_classes [0 ]
169+         exclude_module_name  =  "vace_blocks.0.proj_out" 
166170        components , text_lora_config , denoiser_lora_config  =  self .get_dummy_components (scheduler_cls )
167171        pipe  =  self .pipeline_class (** components ).to (torch_device )
168172        _ , _ , inputs  =  self .get_dummy_inputs (with_generator = False )
@@ -172,22 +176,34 @@ def test_lora_exclude_modules_wanvace(self):
172176
173177        # only supported for `denoiser` now 
174178        denoiser_lora_config .target_modules  =  ["proj_out" ]
175-         denoiser_lora_config .exclude_modules  =  ["vace_blocks.0.proj_out" ]
179+         denoiser_lora_config .exclude_modules  =  [exclude_module_name ]
176180        pipe , _  =  self .add_adapters_to_pipeline (
177181            pipe , text_lora_config = text_lora_config , denoiser_lora_config = denoiser_lora_config 
178182        )
183+         # The state dict shouldn't contain the modules to be excluded from LoRA. 
184+         state_dict_from_model  =  get_peft_model_state_dict (pipe .transformer , adapter_name = "default" )
185+         self .assertTrue (not  any (exclude_module_name  in  k  for  k  in  state_dict_from_model ))
186+         self .assertTrue (any ("proj_out"  in  k  for  k  in  state_dict_from_model ))
179187        output_lora_exclude_modules  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
180188
181189        with  tempfile .TemporaryDirectory () as  tmpdir :
182190            modules_to_save  =  self ._get_modules_to_save (pipe , has_denoiser = True )
183191            lora_state_dicts  =  self ._get_lora_state_dicts (modules_to_save )
184-             lora_metadatas  =  self ._get_lora_adapter_metadata (modules_to_save )
185-             self .pipeline_class .save_lora_weights (save_directory = tmpdir , ** lora_state_dicts , ** lora_metadatas )
192+             self .pipeline_class .save_lora_weights (save_directory = tmpdir , ** lora_state_dicts )
186193            pipe .unload_lora_weights ()
194+ 
195+             # Check in the loaded state dict. 
196+             loaded_state_dict  =  safetensors .torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors" ))
197+             self .assertTrue (not  any (exclude_module_name  in  k  for  k  in  loaded_state_dict ))
198+             self .assertTrue (any ("proj_out"  in  k  for  k  in  loaded_state_dict ))
199+ 
200+             # Check in the state dict obtained after loading LoRA. 
187201            pipe .load_lora_weights (tmpdir )
202+             state_dict_from_model  =  get_peft_model_state_dict (pipe .transformer , adapter_name = "default_0" )
203+             self .assertTrue (not  any (exclude_module_name  in  k  for  k  in  state_dict_from_model ))
204+             self .assertTrue (any ("proj_out"  in  k  for  k  in  state_dict_from_model ))
188205
189206            output_lora_pretrained  =  pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
190- 
191207            self .assertTrue (
192208                not  np .allclose (output_no_lora , output_lora_exclude_modules , atol = 1e-3 , rtol = 1e-3 ),
193209                "LoRA should change outputs." ,
0 commit comments