@@ -345,15 +345,44 @@ def test_keep_modules_in_fp32(self):
         SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
 
         model = SD3Transformer2DModel.from_pretrained(
-            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch_dtype
-        )
+            "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
+        ).to("cuda")
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if name in model._keep_in_fp32_modules:
                     self.assertTrue(module.weight.dtype == torch.float32)
                 else:
                     self.assertTrue(module.weight.dtype == torch_dtype)
+
+        def get_dummy_inputs():
+            batch_size = 2
+            num_channels = 4
+            height = width = embedding_dim = 32
+            pooled_embedding_dim = embedding_dim * 2
+            sequence_length = 154
+
+            hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+            pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+            timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "timestep": timestep,
+            }
+
+        # Test that inference works.
+        with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch_dtype):
+            input_dict_for_transformer = get_dummy_inputs()
+            model_inputs = {
+                k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+            }
+            model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+            _ = model(**model_inputs)
+
         SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
 
 
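For reference, a minimal standalone sketch (not part of the diff) of the `_keep_in_fp32_modules` behavior this test exercises, assuming the `hf-internal-testing/tiny-sd3-pipe` checkpoint and a CUDA device are available:

import torch
from diffusers import SD3Transformer2DModel

# Force the output projection to stay in fp32 even though the model is loaded in fp16.
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]

model = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch.float16
).to("cuda")

# proj_out stays in fp32; the remaining Linear layers are fp16.
print(model.proj_out.weight.dtype)  # torch.float32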