
Commit 4082c43

Fix dummy input preparation and fix some test bugs
1 parent f6c82a3 commit 4082c43

File tree: 1 file changed (+43, −26)

tests/models/transformers/test_models_transformer_flux2.py

Lines changed: 43 additions & 26 deletions
```diff
@@ -104,15 +104,26 @@ def output_shape(self):
     def prepare_dummy_input(self, height=4, width=4):
         batch_size = 1
         num_latent_channels = 4
-        num_image_channels = 3
         sequence_length = 48
         embedding_dim = 32

         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
-        # pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
-        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
-        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+
+        t_coords = torch.arange(1)
+        h_coords = torch.arange(height)
+        w_coords = torch.arange(width)
+        l_coords = torch.arange(1)
+        image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)  # [height * width, 4]
+        image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+        text_t_coords = torch.arange(1)
+        text_h_coords = torch.arange(1)
+        text_w_coords = torch.arange(1)
+        text_l_coords = torch.arange(sequence_length)
+        text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
+        text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
         timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
         guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
```
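The new ID construction can be sanity-checked in isolation. This is a minimal sketch, not part of the commit; names mirror the test:

```python
import torch

# cartesian_prod over the [t, h, w, l] coordinate ranges yields one
# 4-component position ID per latent token.
height, width, batch_size = 4, 4, 1

t_coords = torch.arange(1)        # single frame
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)        # single slot on the fourth axis

image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
assert image_ids.shape == (height * width, 4)

# Broadcast across the batch exactly as the test does.
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1)
assert image_ids.shape == (batch_size, height * width, 4)
```

The four ID components line up with the four entries of the new `axes_dims_rope=[4, 4, 4, 4]` below (one RoPE axis per ID component), which is why the old three-axis `[4, 4, 8]` config had to change alongside the random 3-channel IDs.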

```diff
@@ -135,44 +146,50 @@ def prepare_init_args_and_inputs_for_common(self):
             "attention_head_dim": 16,
             "num_attention_heads": 2,
             "joint_attention_dim": 32,
-            # "pooled_projection_dim": 32,
-            "timestep_guidance_channels": 16,
-            "axes_dims_rope": [4, 4, 8],
+            "timestep_guidance_channels": 256,  # Hardcoded in original code
+            "axes_dims_rope": [4, 4, 4, 4],
         }

         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

-    def test_deprecated_inputs_img_txt_ids_3d(self):
+    def test_flux2_consistency(self, seed=0):
+        torch.manual_seed(seed)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(seed)
         model = self.model_class(**init_dict)
+        # state_dict = model.state_dict()
+        # for key, param in state_dict.items():
+        #     print(f"{key} | {param.shape}")
+        # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
         model.to(torch_device)
         model.eval()

         with torch.no_grad():
-            output_1 = model(**inputs_dict).to_tuple()[0]
+            output = model(**inputs_dict)

-        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
-        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
-        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]

-        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
-        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+        self.assertIsNotNone(output)

-        inputs_dict["txt_ids"] = text_ids_3d
-        inputs_dict["img_ids"] = image_ids_3d
+        # input & output have to have the same shape
+        input_tensor = inputs_dict[self.main_input_name]
+        expected_shape = input_tensor.shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

-        with torch.no_grad():
-            output_2 = model(**inputs_dict).to_tuple()[0]
+        # Check against expected slice
+        # fmt: off
+        expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951])
+        # fmt: on

-        self.assertEqual(output_1.shape, output_2.shape)
-        self.assertTrue(
-            torch.allclose(output_1, output_2, atol=1e-5),
-            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
-        )
+        flat_output = output.cpu().flatten()
+        generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
+        self.assertTrue(torch.allclose(expected_slice, generated_slice))

     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"FluxTransformer2DModel"}
+        expected_set = {"Flux2Transformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     # The test exists for cases like
```
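The new consistency test uses the common expected-slice pattern: flatten the output and compare its first and last 8 values against stored reference values. A hedged sketch of that pattern with an explicit tolerance (the committed assertion relies on `torch.allclose` defaults, `rtol=1e-5` / `atol=1e-8`, which can be brittle across devices; the `atol` here is an illustrative choice, not from the commit):

```python
import torch

def matches_expected_slice(output: torch.Tensor, expected_slice: torch.Tensor, atol: float = 1e-4) -> bool:
    """Compare the first and last 8 flattened output values against a stored slice."""
    flat = output.detach().cpu().flatten()
    generated_slice = torch.cat([flat[:8], flat[-8:]])
    return torch.allclose(expected_slice, generated_slice, atol=atol)
```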
```diff
@@ -205,7 +222,7 @@ def test_lora_exclude_modules(self):
         assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()


-class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = Flux2Transformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
```

```diff
@@ -216,7 +233,7 @@ def prepare_dummy_input(self, height, width):
        return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)


-class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = Flux2Transformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
```
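The `different_shapes_for_compilation` attribute drives the compile and hot-swap mixins through several `(height, width)` pairs, so dynamic-shape and recompilation paths get exercised rather than a single static graph. A rough illustration of the idea with a stand-in module (the real tests compile `Flux2Transformer2DModel` through the mixins):

```python
import torch

model = torch.nn.Linear(4, 4)              # stand-in; not the real transformer
compiled = torch.compile(model, dynamic=True)

for height, width in [(4, 4), (4, 8), (8, 8)]:
    x = torch.randn(1, height * width, 4)  # sequence length varies with the shape
    assert compiled(x).shape == x.shape
```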
