 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
+
 import torch
 
 from diffusers import FluxTransformer2DModel
@@ -46,7 +48,11 @@ class FluxTransformerTesterConfig:
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
     pretrained_model_kwargs = {"subfolder": "transformer"}
 
-    def get_init_dict(self):
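+    # A fresh CPU generator seeded to 0 on every access, so each randn_tensor
+    # call in get_dummy_inputs draws a reproducible stream.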
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list[int]]:
         """Return Flux model initialization arguments."""
         return {
             "patch_size": 1,
@@ -60,30 +66,32 @@ def get_init_dict(self):
6066 "axes_dims_rope" : [4 , 4 , 8 ],
6167 }
6268
63- def get_dummy_inputs (self ):
69+ def get_dummy_inputs (self ) -> dict [ str , torch . Tensor ] :
6470 batch_size = 1
6571 height = width = 4
6672 num_latent_channels = 4
6773 num_image_channels = 3
68- sequence_length = 24
69- embedding_dim = 8
74+ sequence_length = 48
75+ embedding_dim = 32
7076
7177 return {
72- "hidden_states" : randn_tensor ((batch_size , height * width , num_latent_channels )),
73- "encoder_hidden_states" : randn_tensor ((batch_size , sequence_length , embedding_dim )),
74- "pooled_projections" : randn_tensor ((batch_size , embedding_dim )),
75- "img_ids" : randn_tensor ((height * width , num_image_channels )),
76- "txt_ids" : randn_tensor ((sequence_length , num_image_channels )),
78+ "hidden_states" : randn_tensor ((batch_size , height * width , num_latent_channels ), generator = self .generator ),
79+ "encoder_hidden_states" : randn_tensor (
80+ (batch_size , sequence_length , embedding_dim ), generator = self .generator
81+ ),
82+ "pooled_projections" : randn_tensor ((batch_size , embedding_dim ), generator = self .generator ),
83+ "img_ids" : randn_tensor ((height * width , num_image_channels ), generator = self .generator ),
84+ "txt_ids" : randn_tensor ((sequence_length , num_image_channels ), generator = self .generator ),
7785 "timestep" : torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size ),
7886 }
7987
8088 @property
-    def input_shape(self):
-        return (16, 4)
+    def input_shape(self) -> tuple[int, int, int]:
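+        # (batch, height * width, latent_channels), matching the dummy hidden_states.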
+        return (1, 16, 4)
 
     @property
-    def output_shape(self):
-        return (16, 4)
+    def output_shape(self) -> tuple[int, int, int]:
+        return (1, 16, 4)
 
 
 class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
@@ -140,7 +148,7 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""
 
-    def create_ip_adapter_state_dict(self, model):
+    def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
         from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         ip_cross_attn_state_dict = {}
@@ -202,7 +210,7 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin
 
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for LoRA hotswap tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -223,7 +231,7 @@ def get_dummy_inputs(self, height=4, width=4):
 class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for compilation tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -250,7 +258,7 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
 
 
 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -263,7 +271,7 @@ def get_dummy_inputs(self):
 
 
 class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -276,7 +284,7 @@ def get_dummy_inputs(self):
 
 
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -291,7 +299,7 @@ def get_dummy_inputs(self):
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -304,7 +312,7 @@ def get_dummy_inputs(self):
 
 
 class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),