
Commit 7a002e2

single scheduler please.
1 parent 5e181ed commit 7a002e2

File tree

4 files changed: +1031 −1115 lines changed

tests/lora/test_lora_layers_cogview4.py

Lines changed: 17 additions & 18 deletions
@@ -124,30 +124,29 @@ def test_simple_inference_save_pretrained(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        components, _, _ = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
 
-            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                pipe.save_pretrained(tmpdirname)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
 
-                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
-                pipe_from_pretrained.to(torch_device)
+            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+            pipe_from_pretrained.to(torch_device)
 
-                images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
 
-                self.assertTrue(
-                    np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
-                    "Loading from saved checkpoints should give same results.",
-                )
+            self.assertTrue(
+                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+                "Loading from saved checkpoints should give same results.",
+            )
 
     @parameterized.expand([("block_level", True), ("leaf_level", False)])
     @require_torch_accelerator
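
The calls above drop the scheduler_cls argument, so get_dummy_components() now has to construct the one scheduler itself. A minimal sketch of that pattern, assuming the test mixin pins a scheduler class as a class attribute; the names below are illustrative, not the actual diffusers test code:

    # Hypothetical sketch of the single-scheduler pattern this commit moves to.
    from diffusers import FlowMatchEulerDiscreteScheduler

    class ExamplePipelineLoRATests:  # stand-in for the real test mixin
        scheduler_cls = FlowMatchEulerDiscreteScheduler  # one pinned scheduler
        scheduler_kwargs = {}

        def get_dummy_components(self):
            # Build the single pinned scheduler instead of looping over several classes.
            scheduler = self.scheduler_cls(**self.scheduler_kwargs)
            components = {"scheduler": scheduler}
            # ...the real mixin also builds the dummy transformer, text encoder, VAE, etc.
            return components, None, None

Pinning one scheduler keeps the LoRA assertions identical across runs while shrinking the test matrix.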

tests/lora/test_lora_layers_lumina2.py

Lines changed: 27 additions & 30 deletions
@@ -141,33 +141,30 @@ def test_simple_inference_with_text_lora_save_load(self):
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-            # corrupt one LoRA weight with `inf` values
-            with torch.no_grad():
-                pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
-
-            # with `safe_fusing=True` we should see an Error
-            with self.assertRaises(ValueError):
-                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
-            # without we should not see an error, but every image will be black
-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
-            out = pipe(**inputs)[0]
-
-            self.assertTrue(np.isnan(out).all())
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+        out = pipe(**inputs)[0]
+
+        self.assertTrue(np.isnan(out).all())
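
check_if_lora_correctly_set comes from the shared LoRA test utilities. A minimal sketch of what such a check typically does, assuming PEFT marks adapted submodules with BaseTunerLayer; the real helper may differ in detail:

    from peft.tuners.tuners_utils import BaseTunerLayer

    def check_if_lora_correctly_set(model) -> bool:
        # LoRA counts as applied if any submodule is a PEFT tuner layer
        # (e.g. a lora.Linear wrapping the original projection).
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                return True
        return False

The NaN test then leans on fuse_lora: with safe_fusing=True the fused weights are validated and a ValueError is raised on inf/NaN values, while safe_fusing=False fuses silently and the corruption surfaces as an all-NaN output.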

tests/lora/test_lora_layers_wanvace.py

Lines changed: 1 addition & 2 deletions
@@ -165,9 +165,8 @@ def test_layerwise_casting_inference_denoiser(self):
 
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
-        scheduler_cls = self.scheduler_classes[0]
         exclude_module_name = "vace_blocks.0.proj_out"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
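
The excluded module maps to PEFT's LoraConfig. A hypothetical sketch of a config that adapts the usual projections but skips one VACE block, assuming a peft version where exclude_modules is available (the test gates on peft > 0.13.2); the rank and target names are placeholders:

    from peft import LoraConfig

    # Hypothetical config; exclude_modules keeps one submodule un-adapted.
    denoiser_lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["to_q", "to_k", "to_v", "proj_out"],
        exclude_modules=["vace_blocks.0.proj_out"],  # module the test leaves without LoRA
    )

Assertions can then check that vace_blocks.0.proj_out carries no lora_A/lora_B parameters after the adapter is added.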
