Commit b12f797

updates for tests

1 parent a8c50ba
File tree: 1 file changed (+15, −45)

tests/lora/test_lora_layers_flux.py

Lines changed: 15 additions & 45 deletions
@@ -111,30 +111,6 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    def get_dummy_tensor_inputs(self, device=None):
-        batch_size = 1
-        num_latent_channels = 4
-        num_image_channels = 3
-        height = width = 4
-        sequence_length = 48
-        embedding_dim = 32
-
-        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
-        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
-        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
-        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
-        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
-
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_prompt_embeds,
-            "txt_ids": text_ids,
-            "img_ids": image_ids,
-            "timestep": timestep,
-        }
-
     def test_with_alpha_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -189,13 +165,12 @@ def test_with_norm_in_state_dict(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        inputs = self.get_dummy_tensor_inputs(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.INFO)
 
-        with torch.no_grad():
-            original_output = pipe.transformer(**inputs)[0]
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]:
             norm_state_dict = {}
@@ -206,18 +181,19 @@ def test_with_norm_in_state_dict(self):
                     module.weight.shape, device=module.weight.device, dtype=module.weight.dtype
                 )
 
-            with torch.no_grad():
             with CaptureLogger(logger) as cap_logger:
                 pipe.load_lora_weights(norm_state_dict)
-                lora_load_output = pipe.transformer(**inputs)[0]
+            lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
             self.assertTrue(
                 cap_logger.out.startswith(
                     "The provided state dict contains normalization layers in addition to LoRA layers"
                 )
             )
+            self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
 
             pipe.unload_lora_weights()
-            lora_unload_output = pipe.transformer(**inputs)[0]
+            lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(pipe.transformer._transformer_norm_layers is None)
             self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5))
@@ -238,14 +214,11 @@ def test_lora_parameter_expanded_shapes(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        inputs = self.get_dummy_tensor_inputs(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.DEBUG)
 
-        with torch.no_grad():
-            original_output = pipe.transformer(**inputs)[0]
-
         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4
 
@@ -257,12 +230,12 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            inputs["hidden_states"] = torch.cat([inputs["hidden_states"]] * 2, dim=2)
-            with torch.no_grad():
-                expanded_output = pipe.transformer(**inputs)[0]
+
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+
         pipe.delete_adapters("adapter-1")
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
-        self.assertFalse(np.allclose(original_output, expanded_output, atol=1e-3, rtol=1e-3))
 
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
@@ -286,24 +259,21 @@ def test_lora_B_bias(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        inputs = self.get_dummy_tensor_inputs(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.INFO)
 
-        with torch.no_grad():
-            original_output = pipe.transformer(**inputs)[0]
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         denoiser_lora_config.lora_bias = False
         pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        with torch.no_grad():
-            lora_bias_false_output = pipe.transformer(**inputs)[0]
+        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
         pipe.delete_adapters("adapter-1")
 
         denoiser_lora_config.lora_bias = True
         pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        with torch.no_grad():
-            lora_bias_true_output = pipe.transformer(**inputs)[0]
+        lora_bias_true_output = pipe(**inputs)[0]
 
         self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
         self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
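Note: the diff above replaces direct pipe.transformer(**inputs) calls on hand-built dummy tensors with end-to-end pipeline calls that pass generator=torch.manual_seed(0). The snippet below is a minimal, self-contained sketch of why seeding the generator on every call makes those output comparisons meaningful; toy_pipeline is a hypothetical stand-in for the Flux pipeline, not diffusers code.

import numpy as np
import torch


def toy_pipeline(generator: torch.Generator) -> np.ndarray:
    # Stand-in for a diffusion pipeline call: the only source of randomness
    # here is the noise drawn from the supplied generator.
    noise = torch.randn((1, 4, 4), generator=generator)
    return noise.numpy()


# Seeding the generator per call makes the sampled noise identical across runs,
# so any difference between outputs can be attributed to the LoRA / norm weights
# rather than to random noise.
out_a = toy_pipeline(generator=torch.manual_seed(0))
out_b = toy_pipeline(generator=torch.manual_seed(0))
assert np.allclose(out_a, out_b, atol=1e-5, rtol=1e-5)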

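Note: the assertions added in test_lora_parameter_expanded_shapes check that loading the wider LoRA leaves x_embedder with a doubled input dimension (2 * in_features). The snippet below is a hypothetical illustration of that kind of nn.Linear expansion and shape check, not diffusers' actual expansion logic inside load_lora_weights.

import torch
import torch.nn as nn

in_features, out_features = 16, 32
base = nn.Linear(in_features, out_features, bias=True)

# Build a wider layer and keep the original mapping in the first half of the input.
expanded = nn.Linear(2 * in_features, out_features, bias=True)
with torch.no_grad():
    expanded.weight.zero_()
    expanded.weight[:, :in_features] = base.weight
    expanded.bias.copy_(base.bias)

# Mirrors the shape check the test now performs on pipe.transformer.x_embedder:
# the weight's second dimension reflects the doubled input width.
assert expanded.weight.data.shape[1] == 2 * in_features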