Skip to content

Commit 7716303

Browse files
committed
tigher tests.
1 parent a9f5088 commit 7716303

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

tests/models/test_modeling_common.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,10 +1060,10 @@ def test_deprecated_kwargs(self):
10601060
" from `_deprecated_kwargs = [<deprecated_argument>]`"
10611061
)
10621062

1063-
@parameterized.expand([True, False])
1063+
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
10641064
@torch.no_grad()
10651065
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
1066-
def test_save_load_lora_adapter(self, use_dora=False):
1066+
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
10671067
from peft import LoraConfig
10681068
from peft.utils import get_peft_model_state_dict
10691069

@@ -1079,8 +1079,8 @@ def test_save_load_lora_adapter(self, use_dora=False):
10791079
output_no_lora = model(**inputs_dict, return_dict=False)[0]
10801080

10811081
denoiser_lora_config = LoraConfig(
1082-
r=4,
1083-
lora_alpha=4,
1082+
r=rank,
1083+
lora_alpha=lora_alpha,
10841084
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
10851085
init_lora_weights=False,
10861086
use_dora=use_dora,
@@ -1147,12 +1147,12 @@ def test_wrong_adapter_name_raises_error(self):
11471147

11481148
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
11491149

1150+
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
11501151
@torch.no_grad()
11511152
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
1152-
def test_adapter_metadata_is_loaded_correctly(self):
1153+
def test_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
11531154
from peft import LoraConfig
11541155

1155-
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
11561156
from diffusers.loaders.peft import PeftAdapterMixin
11571157

11581158
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -1162,11 +1162,11 @@ def test_adapter_metadata_is_loaded_correctly(self):
11621162
return
11631163

11641164
denoiser_lora_config = LoraConfig(
1165-
r=4,
1166-
lora_alpha=4,
1165+
r=rank,
1166+
lora_alpha=lora_alpha,
11671167
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
11681168
init_lora_weights=False,
1169-
use_dora=False,
1169+
use_dora=use_dora,
11701170
)
11711171
model.add_adapter(denoiser_lora_config)
11721172
metadata = model.peft_config["default"].to_dict()
@@ -1177,15 +1177,12 @@ def test_adapter_metadata_is_loaded_correctly(self):
11771177
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
11781178
self.assertTrue(os.path.isfile(model_file))
11791179

1180-
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
1181-
if hasattr(f, "metadata"):
1182-
parsed_metadata = f.metadata()
1183-
parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"}
1184-
self.assertTrue(LORA_ADAPTER_METADATA_KEY in parsed_metadata)
1185-
parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"}
1180+
model.unload_lora()
1181+
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
11861182

1187-
parsed_metadata = json.loads(parsed_metadata[LORA_ADAPTER_METADATA_KEY])
1188-
check_if_dicts_are_equal(parsed_metadata, metadata)
1183+
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
1184+
parsed_metadata = model.peft_config["default_0"].to_dict()
1185+
check_if_dicts_are_equal(metadata, parsed_metadata)
11891186

11901187
@require_torch_accelerator
11911188
def test_cpu_offload(self):

0 commit comments

Comments
 (0)