@@ -104,15 +104,26 @@ def output_shape(self):
     def prepare_dummy_input(self, height=4, width=4):
         batch_size = 1
         num_latent_channels = 4
-        num_image_channels = 3
         sequence_length = 48
         embedding_dim = 32
 
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
-        # pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
-        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
-        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+
+        t_coords = torch.arange(1)
+        h_coords = torch.arange(height)
+        w_coords = torch.arange(width)
+        l_coords = torch.arange(1)
+        image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)  # [height * width, 4]
+        image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+        text_t_coords = torch.arange(1)
+        text_h_coords = torch.arange(1)
+        text_w_coords = torch.arange(1)
+        text_l_coords = torch.arange(sequence_length)
+        text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
+        text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
         timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
         guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
 
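For context on the hunk above: `torch.cartesian_prod` enumerates every combination of the per-axis coordinate ranges, so the image IDs become one `(t, h, w, l)` row per latent token instead of the old random 3-channel IDs. A minimal sketch of what that produces for a 2×2 grid:

```python
import torch

height, width = 2, 2
image_ids = torch.cartesian_prod(
    torch.arange(1),       # t: single frame
    torch.arange(height),  # h coordinates
    torch.arange(width),   # w coordinates
    torch.arange(1),       # l: single trailing axis
)
print(image_ids.shape)  # torch.Size([4, 4]) -> height * width rows, one column per axis
print(image_ids)
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 0],
#         [0, 1, 0, 0],
#         [0, 1, 1, 0]])
```

The text IDs are built the same way, with only the last axis varying over `sequence_length`, so each text token gets a distinct `l` coordinate.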
@@ -135,44 +146,50 @@ def prepare_init_args_and_inputs_for_common(self):
             "attention_head_dim": 16,
             "num_attention_heads": 2,
             "joint_attention_dim": 32,
-            # "pooled_projection_dim": 32,
-            "timestep_guidance_channels": 16,
-            "axes_dims_rope": [4, 4, 8],
+            "timestep_guidance_channels": 256,  # hardcoded in the original code
+            "axes_dims_rope": [4, 4, 4, 4],
         }
 
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_deprecated_inputs_img_txt_ids_3d(self):
+    def test_flux2_consistency(self, seed=0):
+        torch.manual_seed(seed)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(seed)
         model = self.model_class(**init_dict)
+        # state_dict = model.state_dict()
+        # for key, param in state_dict.items():
+        #     print(f"{key} | {param.shape}")
+        # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
         model.to(torch_device)
         model.eval()
 
         with torch.no_grad():
-            output_1 = model(**inputs_dict).to_tuple()[0]
+            output = model(**inputs_dict)
 
-        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
-        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
-        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]
 
-        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
-        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+        self.assertIsNotNone(output)
 
-        inputs_dict["txt_ids"] = text_ids_3d
-        inputs_dict["img_ids"] = image_ids_3d
+        # Input and output must have the same shape.
+        input_tensor = inputs_dict[self.main_input_name]
+        expected_shape = input_tensor.shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
-        with torch.no_grad():
-            output_2 = model(**inputs_dict).to_tuple()[0]
+        # Check against the expected slice.
+        # fmt: off
+        expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951])
+        # fmt: on
 
-        self.assertEqual(output_1.shape, output_2.shape)
-        self.assertTrue(
-            torch.allclose(output_1, output_2, atol=1e-5),
-            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) is not equal to output with them as 2d inputs",
-        )
+        flat_output = output.cpu().flatten()
+        generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
+        self.assertTrue(torch.allclose(expected_slice, generated_slice))
 
     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"FluxTransformer2DModel"}
+        expected_set = {"Flux2Transformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     # The test exists for cases like
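The consistency test above pins the model output to a stored reference slice. A minimal sketch of how such a slice is typically (re)generated, assuming the same seeding pattern as the test (seed once before input creation and again before model construction so both are deterministic):

```python
import torch

def make_slice(output: torch.Tensor, n: int = 8) -> torch.Tensor:
    # First and last n elements of the flattened output, matching
    # generated_slice in the test above.
    flat = output.detach().cpu().flatten()
    return torch.cat([flat[:n], flat[-n:]])

# Stand-in tensor for illustration; the real test uses the model output.
torch.manual_seed(0)
print(make_slice(torch.randn(1, 16, 4)))
```

One observation: the bare `torch.allclose(expected_slice, generated_slice)` relies on the default tolerances (`rtol=1e-5`, `atol=1e-8`), which can be brittle across devices; tests like this often pass an explicit `atol` for that reason.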
@@ -205,7 +222,7 @@ def test_lora_exclude_modules(self):
         assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
 
 
-class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = Flux2Transformer2DModel
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
@@ -216,7 +233,7 @@ def prepare_dummy_input(self, height, width):
         return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
 
 
-class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = Flux2Transformer2DModel
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
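A closing note on the config hunk earlier in the diff: with four position-ID axes `(t, h, w, l)`, `axes_dims_rope` now has four entries, and, assuming Flux2 keeps the Flux convention that the per-axis RoPE dims fill the attention head dim, the entries must sum to `attention_head_dim`. A quick sanity-check sketch using the values from the diff:

```python
attention_head_dim = 16
axes_dims_rope = [4, 4, 4, 4]
num_id_axes = 4  # (t, h, w, l), matching the cartesian_prod construction above

assert len(axes_dims_rope) == num_id_axes, "one RoPE dim per position-ID axis"
assert sum(axes_dims_rope) == attention_head_dim, "RoPE dims should fill the head dim"
```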