
Commit be1d788

add a test for varied lora ranks and alphas.
1 parent 118ed9b commit be1d788

File tree

1 file changed: +45, -1 lines changed


tests/lora/test_lora_layers_flux.py

Lines changed: 45 additions & 1 deletion
@@ -159,6 +159,7 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
+    # flux control lora specific
     def test_with_norm_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -210,6 +211,7 @@ def test_with_norm_in_state_dict(self):
             cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
         )
 
+    # flux control lora specific
     def test_lora_parameter_expanded_shapes(self):
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -254,6 +256,7 @@ def test_lora_parameter_expanded_shapes(self):
         with self.assertRaises(NotImplementedError):
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
+    # flux control lora specific
     @require_peft_version_greater("0.13.2")
     def test_lora_B_bias(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -275,12 +278,53 @@ def test_lora_B_bias(self):
 
         denoiser_lora_config.lora_bias = True
         pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        lora_bias_true_output = pipe(**inputs)[0]
+        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
         self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
         self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
 
+    # for now this is flux control lora specific but can be generalized later and added to ./utils.py
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe.transformer.delete_adapters("adapter-1")
+
+        # change the rank_pattern
+        updated_rank = denoiser_lora_config.r * 2
+        denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
+            "single_transformer_blocks.0.attn.to_k": updated_rank
+        }
+
+        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        pipe.transformer.delete_adapters("adapter-1")
+
+        # similarly change the alpha_pattern
+        updated_alpha = denoiser_lora_config.lora_alpha * 2
+        denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
+            "single_transformer_blocks.0.attn.to_k": updated_alpha
+        }
+
+        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
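
The new test exercises PEFT's rank_pattern and alpha_pattern options, which override the default r and lora_alpha for the layers they match. Below is a minimal, self-contained sketch of that mechanism on a hypothetical toy module; the TinyAttention class and the bare to_q / to_k layer names are illustrative assumptions and are not part of the diff or the diffusers test suite.

# Minimal sketch (assumed toy module, not the diffusers test): rank_pattern and
# alpha_pattern in peft.LoraConfig give individual layers their own rank and alpha,
# which is what the new test asserts via pipe.transformer.peft_config["adapter-1"].
import torch
from peft import LoraConfig, get_peft_model


class TinyAttention(torch.nn.Module):
    # Hypothetical stand-in for an attention block; to_q / to_k mirror the naming
    # convention behind the test's "single_transformer_blocks.0.attn.to_k" key.
    def __init__(self, dim=8):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)


config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k"],
    rank_pattern={"to_k": 8},    # to_k gets rank 8; to_q keeps the default r=4
    alpha_pattern={"to_k": 16},  # to_k also gets its own scaling alpha
)
model = get_peft_model(TinyAttention(), config)

# The injected LoRA A matrices reflect the per-layer ranks ...
print(model.base_model.model.to_q.lora_A["default"].weight.shape)  # torch.Size([4, 8])
print(model.base_model.model.to_k.lora_A["default"].weight.shape)  # torch.Size([8, 8])
# ... and to_k's scaling becomes lora_alpha / r = 16 / 8.
print(model.base_model.model.to_k.scaling["default"])  # 2.0

In the diff above, the same pattern keys are asserted to round-trip through pipe.transformer.peft_config after add_adapter, and every pipeline call is seeded with torch.manual_seed(0) so the output comparisons are deterministic.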
