@@ -252,9 +252,20 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(base_img_prompt + e) for e in prompt]
 
+        if image is None:
+            images_for_processor = None
+        else:
+            # If `image` is a single image (not a list), the processor will broadcast it.
+            # If `image` is a list of conditioning images, we must repeat that list
+            # for each prompt so the processor has one entry per text example.
+            if isinstance(image, list):
+                images_for_processor = [image] * len(txt)
+            else:
+                images_for_processor = image
+
         model_inputs = self.processor(
             text=txt,
-            images=image,
+            images=images_for_processor,
             padding=True,
             return_tensors="pt",
         ).to(device)
@@ -627,7 +638,12 @@ def __call__(
                [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
                returning a tuple, the first element is a list with the generated images.
         """
-        image_size = image[-1].size if isinstance(image, list) else image.size
+        # Use the first image's size as the deterministic base for output dims
+        ref_img = image[0] if isinstance(image, list) else image
+        if isinstance(ref_img, (tuple, list)):
+            ref_img = ref_img[0]
+        image_size = ref_img.size
+
         calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
@@ -673,6 +689,7 @@ def __call__(
         vae_image_sizes = []
         vae_images = []
         for img in image:
+            img = img[0] if isinstance(img, (tuple, list)) else img
             image_width, image_height = img.size
             condition_width, condition_height = calculate_dimensions(
                 CONDITION_IMAGE_SIZE, image_width / image_height
@@ -681,7 +698,10 @@ def __call__(
             condition_image_sizes.append((condition_width, condition_height))
             vae_image_sizes.append((vae_width, vae_height))
             condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
-            vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+            preproc = self.image_processor.preprocess(img, vae_height, vae_width)
+            if isinstance(preproc, (tuple, list)):
+                preproc = preproc[0]
+            vae_images.append(preproc.unsqueeze(0))
 
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -719,6 +739,25 @@ def __call__(
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
+        if vae_images is not None:
+            for idx, v in enumerate(vae_images):
+                if isinstance(v, (tuple, list)):
+                    v = v[0]
+
+                if not torch.is_tensor(v):
+                    v = torch.as_tensor(v)
+
+                if v.ndim == 5 and v.shape[1] == 1 and v.shape[2] in (1, 3):
+                    v = v.permute(0, 2, 1, 3, 4).contiguous()
+
+                elif v.ndim == 4 and v.shape[1] in (1, 3):
+                    v = v.unsqueeze(2)
+
+                elif v.ndim == 3 and v.shape[0] in (1, 3):
+                    v = v.unsqueeze(0).unsqueeze(2)
+
+                vae_images[idx] = v
+
         latents, image_latents = self.prepare_latents(
             vae_images,
             batch_size * num_images_per_prompt,
@@ -730,15 +769,12 @@ def __call__(
             generator,
             latents,
         )
-        img_shapes = [
-            [
-                (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
-                *[
-                    (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
-                    for vae_width, vae_height in vae_image_sizes
-                ],
-            ]
-        ] * batch_size
+        base_shape = (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)
+        per_image_shapes = [
+            (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+            for vae_width, vae_height in vae_image_sizes
+        ]
+        img_shapes = [[base_shape, *per_image_shapes] for _ in range(batch_size)]
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
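
Taken together, these changes let the pipeline accept a plain list of conditioning images for a single prompt and normalize each entry before VAE encoding. A minimal usage sketch under that assumption (the pipeline class name, checkpoint id, file names, and prompt below are illustrative, not part of this diff):

import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline  # assumed class; this diff only shows its internals

# Hypothetical checkpoint id and input files, for illustration only.
pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16).to("cuda")
cond_images = [Image.open("subject.png"), Image.open("scene.png")]  # list of conditioning images for one prompt
out = pipe(image=cond_images, prompt="Place the subject into the scene", num_inference_steps=30)
out.images[0].save("edited.png")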