@@ -1784,18 +1784,15 @@ def test_missing_keys_warning(self):
17841784 missing_key = [k for k in state_dict if "lora_A" in k ][0 ]
17851785 del state_dict [missing_key ]
17861786
1787- logger = (
1788- logging .get_logger ("diffusers.loaders.unet" )
1789- if self .unet_kwargs is not None
1790- else logging .get_logger ("diffusers.loaders.peft" )
1791- )
1787+ logger = logging .get_logger ("diffusers.loaders.peft" )
17921788 logger .setLevel (30 )
17931789 with CaptureLogger (logger ) as cap_logger :
17941790 pipe .load_lora_weights (state_dict )
17951791
17961792 # Since the missing key won't contain the adapter name ("default_0").
17971793 # Also strip out the component prefix (such as "unet." from `missing_key`).
17981794 component = list ({k .split ("." )[0 ] for k in state_dict })[0 ]
1795+ print (f"{ cap_logger .out = } " )
17991796 self .assertTrue (missing_key .replace (f"{ component } ." , "" ) in cap_logger .out .replace ("default_0." , "" ))
18001797
18011798 def test_unexpected_keys_warning (self ):
@@ -1823,11 +1820,7 @@ def test_unexpected_keys_warning(self):
18231820 unexpected_key = [k for k in state_dict if "lora_A" in k ][0 ] + ".diffusers_cat"
18241821 state_dict [unexpected_key ] = torch .tensor (1.0 , device = torch_device )
18251822
1826- logger = (
1827- logging .get_logger ("diffusers.loaders.unet" )
1828- if self .unet_kwargs is not None
1829- else logging .get_logger ("diffusers.loaders.peft" )
1830- )
1823+ logger = logging .get_logger ("diffusers.loaders.peft" )
18311824 logger .setLevel (30 )
18321825 with CaptureLogger (logger ) as cap_logger :
18331826 pipe .load_lora_weights (state_dict )
0 commit comments