Commit e70f265

fixes

1 parent 5694f11 commit e70f265


tests/lora/utils.py

Lines changed: 3 additions & 10 deletions
@@ -1784,18 +1784,15 @@ def test_missing_keys_warning(self):
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
 
         # Since the missing key won't contain the adapter name ("default_0").
         # Also strip out the component prefix (such as "unet." from `missing_key`).
         component = list({k.split(".")[0] for k in state_dict})[0]
+        print(f"{cap_logger.out=}")
         self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
 
     def test_unexpected_keys_warning(self):
@@ -1823,11 +1820,7 @@ def test_unexpected_keys_warning(self):
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
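
For context, here is a minimal sketch of the logger-capture pattern these tests rely on, assuming `diffusers` is installed. It targets the same `"diffusers.loaders.peft"` logger the patched tests now always use; the warning text is illustrative only, since the real tests trigger the warning through `pipe.load_lora_weights(state_dict)` after tampering with the state dict.

```python
# Minimal sketch of the capture pattern, assuming diffusers is installed.
# The warning message below is made up for illustration.
import logging as std_logging

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Same module logger the updated tests now always use.
logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(std_logging.WARNING)  # equivalent to logger.setLevel(30) in the diff

with CaptureLogger(logger) as cap_logger:
    # Stand-in for pipe.load_lora_weights(state_dict) emitting a missing-keys warning.
    logger.warning("Loading adapter weights led to missing keys: ['unet.foo.lora_A.weight']")

# cap_logger.out holds everything the logger emitted inside the block.
assert "lora_A" in cap_logger.out
print(f"{cap_logger.out=}")
```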
