Skip to content

Commit ead9e0c

Browse files
committed
fix
1 parent c00e350 commit ead9e0c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/lora/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1952,12 +1952,15 @@ def test_set_adapters_match_attention_kwargs(self):
19521952
)
19531953

19541954
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
1955-
pipe.unload_lora_weights()
1955+
pipe = self.pipeline_class(**components)
1956+
pipe = pipe.to(torch_device)
1957+
pipe.set_progress_bar_config(disable=None)
19561958
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
19571959

19581960
for module_name, module in modules_to_save.items():
19591961
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
19601962

1963+
print(f"{attention_kwargs=}")
19611964
output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
19621965
self.assertTrue(
19631966
not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),

0 commit comments

Comments
 (0)