@@ -184,12 +184,8 @@ def log_validation(
184184 for _ in range (args .num_validation_images ):
185185 with autocast_ctx :
186186 image = pipeline (
187- prompt_embeds_t5 = pipeline_args ["prompt_embeds_t5" ],
188- prompt_embeds_llama3 = pipeline_args ["prompt_embeds_llama3" ],
189- negative_prompt_embeds_t5 = pipeline_args ["negative_prompt_embeds_t5" ],
190- negative_prompt_embeds_llama3 = pipeline_args ["negative_prompt_embeds_llama3" ],
191- pooled_prompt_embeds = pipeline_args ["pooled_prompt_embeds" ],
192- negative_pooled_prompt_embeds = pipeline_args ["negative_pooled_prompt_embeds" ],
187+ prompt_embeds = pipeline_args ["prompt_embeds" ],
188+ prompt_embeds_mask = pipeline_args ["prompt_embeds_mask" ],
193189 generator = generator ,
194190 ).images [0 ]
195191 images .append (image )
@@ -476,7 +472,7 @@ def parse_args(input_args=None):
476472 parser .add_argument (
477473 "--guidance_scale" ,
478474 type = float ,
479- default = 3.5 ,
475+ default = 0.0 ,
480476 help = "Qwen image is a guidance distilled model" ,
481477 )
482478 parser .add_argument (
@@ -1495,18 +1491,20 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14951491 img_shapes = [
14961492 (1 , args .resolution // vae_scale_factor // 2 , args .resolution // vae_scale_factor // 2 )
14971493 ] * bsz
1494+ # transpose the dimensions
1495+ noisy_model_input = noisy_model_input .permute (0 , 2 , 1 , 3 , 4 )
14981496 packed_noisy_model_input = QwenImagePipeline ._pack_latents (
14991497 noisy_model_input ,
15001498 batch_size = model_input .shape [0 ],
15011499 num_channels_latents = model_input .shape [1 ],
1502- height = model_input .shape [2 ],
1503- width = model_input .shape [3 ],
1500+ height = model_input .shape [3 ],
1501+ width = model_input .shape [4 ],
15041502 )
15051503 model_pred = transformer (
15061504 hidden_states = packed_noisy_model_input ,
1507- encoder_hidden_states_t5 = prompt_embeds ,
1505+ encoder_hidden_states = prompt_embeds ,
15081506 encoder_hidden_states_mask = prompt_embeds_mask ,
1509- timesteps = timesteps / 1000 ,
1507+ timestep = timesteps / 1000 ,
15101508 guidance = guidance ,
15111509 img_shapes = img_shapes ,
15121510 txt_seq_lens = prompt_embeds_mask .sum (dim = 1 ).tolist (),
0 commit comments