
Commit 102e98d

update

1 parent 763dba5 commit 102e98d
1 file changed

tests/lora/test_lora_layers_wanvace.py

Lines changed: 20 additions & 4 deletions
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import sys
 import tempfile
 import unittest
 
 import numpy as np
 import pytest
+import safetensors.torch
 import torch
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -163,6 +166,7 @@ def test_layerwise_casting_inference_denoiser(self):
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
         scheduler_cls = self.scheduler_classes[0]
+        exclude_module_name = "vace_blocks.0.proj_out"
         components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -172,22 +176,34 @@ def test_lora_exclude_modules_wanvace(self):
 
         # only supported for `denoiser` now
         denoiser_lora_config.target_modules = ["proj_out"]
-        denoiser_lora_config.exclude_modules = ["vace_blocks.0.proj_out"]
+        denoiser_lora_config.exclude_modules = [exclude_module_name]
         pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
+        # The state dict shouldn't contain the modules to be excluded from LoRA.
+        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
+        self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+        self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
         output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdir:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
             pipe.unload_lora_weights()
+
+            # Check in the loaded state dict.
+            loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
+            self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+
+            # Check in the state dict obtained after loading LoRA.
             pipe.load_lora_weights(tmpdir)
+            state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+            self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+            self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
 
             output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
         self.assertTrue(
             not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
             "LoRA should change outputs.",

0 commit comments