 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AuraFlowTransformer2DModel, AutoencoderKL
 from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput


 if is_torch_xla_available():
@@ -119,12 +119,12 @@ def check_inputs(
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
-            
+
         patch_size = 2  # AuraFlow uses patch size of 2
         required_divisor = self.vae_scale_factor * patch_size
         if height % required_divisor != 0 or width % required_divisor != 0:
             raise ValueError(
-                f"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the patch size ({patch_size}), i.e. {required_divisor}. "
+                rf"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the patch size ({patch_size}), i.e. {required_divisor}. "
                 f"Your dimensions are ({height}, {width})."
             )

@@ -339,7 +339,7 @@ def prepare_latents(

     def get_timesteps(self, num_inference_steps, strength, device):
         # 1. Call set_timesteps with num_inference_steps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)  # Ensure scheduler uses the correct number of steps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)

         # 2. Calculate strength-based number of steps and offset
         init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps)
@@ -353,14 +353,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         return timesteps, num_actual_inference_steps

     def prepare_img2img_latents(
-        self,
-        image,
-        timestep,
-        batch_size,
-        num_images_per_prompt,
-        dtype,
-        device,
-        generator=None
+        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
     ):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
@@ -380,34 +373,87 @@ def prepare_img2img_latents(
                     f" size of {batch_size}"
                 )

-            if image.shape[0] == 1:
-                image = image.repeat(batch_size, 1, 1, 1)
+            # Handle different batch size scenarios
+            if image.shape[0] < batch_size:
+                if batch_size % image.shape[0] == 0:
+                    # Duplicate the image to match the batch size
+                    additional_image_per_prompt = batch_size // image.shape[0]
+                    image = torch.cat([image] * additional_image_per_prompt, dim=0)
+                else:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch size {batch_size}."
+                        f" Batch size must be divisible by the image batch size."
+                    )

             # encode the init image into latents and scale the latents
-            latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            # 1. Get VAE distribution parameters (on device)
+            latent_dist = self.vae.encode(image).latent_dist
+            mean, std = latent_dist.mean, latent_dist.std  # Already on device
+
+            # 2. Sample noise for each batch element individually if using multiple generators
+            if isinstance(generator, list):
+                sample = torch.cat(
+                    [
+                        torch.randn(
+                            (1, *mean.shape[1:]),
+                            generator=generator[i],
+                            device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                            dtype=mean.dtype,
+                        ).to(mean.device)
+                        for i in range(batch_size)
+                    ]
+                )
+            else:
+                # Single generator - use its device if it has one
+                generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
+                noise = torch.randn(mean.shape, generator=generator, device=generator_device, dtype=mean.dtype)
+                sample = noise.to(mean.device)
+
+            # Compute latents
+            latents = mean + std * sample
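+            # Note: mean + std * sample reproduces latent_dist.sample(generator=generator),
+            # with the noise drawn explicitly so the generator's device is respected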
+
+            # Scale latents
             latents = latents * self.vae.config.scaling_factor

             # get the original timestep using init_timestep
             init_timestep = timestep

             # add noise to latents using the timesteps
-            noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
-            
+            # Handle noise generation with multiple generators if provided
+            if isinstance(generator, list):
+                noise = torch.cat(
+                    [
+                        torch.randn(
+                            (1, *latents.shape[1:]),
+                            generator=generator[i],
+                            device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                            dtype=latents.dtype,
+                        ).to(latents.device)
+                        for i in range(batch_size)
+                    ]
+                )
+            else:
+                # Single generator - use its device if it has one
+                generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
+                noise = torch.randn(
+                    latents.shape, generator=generator, device=generator_device, dtype=latents.dtype
+                ).to(latents.device)
+
             # Ensure timestep tensor is on the same device
             t = init_timestep.to(latents.device)
-            
+
             # Normalize timestep to [0, 1] range (using scheduler's config)
             t = t / self.scheduler.config.num_train_timesteps
-            
+
             # Reshape t to match the dimensions needed for broadcasting
             required_dims = len(latents.shape)
             current_dims = len(t.shape)
             for _ in range(required_dims - current_dims):
                 t = t.unsqueeze(-1)
-            
+
             # Interpolation: x_t = t * x_1 + (1 - t) * x_0
             latents = t * noise + (1 - t) * latents
-            
+
         return latents

     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
@@ -606,13 +652,14 @@ def __call__(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             max_sequence_length=max_sequence_length,
         )
+
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

         # 5. Prepare timesteps
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1]
-        
+
         # 6. Prepare latent variables
         latents = self.prepare_img2img_latents(
             image,
@@ -632,10 +679,13 @@ def __call__(
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

-                # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
-                timestep = timestep.to(latents.device, dtype=latents.dtype)
+                # AuraFlow uses timestep values between 0 and 1, with t=1 as noise and t=0 as the image
+                # create a timestep tensor with the correct batch size
+                # ensure it matches the batch size of the model input
+                t_float = t / 1000
+                timestep_tensor = torch.full(
+                    (latent_model_input.shape[0],), t_float, device=latents.device, dtype=latents.dtype
+                )
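+                # Note: torch.full materializes a (batch_size,)-shaped timestep tensor directly on the
+                # latents' device and dtype, replacing the earlier tensor([t / 1000]).expand(...) pattern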

                 # Make sure latent_model_input has the same dtype as the transformer
                 transformer_dtype = self.transformer.dtype
@@ -646,7 +696,7 @@ def __call__(
                 noise_pred = self.transformer(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
+                    timestep=timestep_tensor,
                     return_dict=False,
                 )[0]

@@ -682,15 +732,19 @@ def __call__(
             if needs_upcasting:
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-            
+
             # Apply proper scaling factor and shift factor if available
-            if hasattr(self.vae.config, "scaling_factor") and hasattr(self.vae.config, "shift_factor") and getattr(self.vae.config, "shift_factor", None) is not None:
+            if (
+                hasattr(self.vae.config, "scaling_factor")
+                and hasattr(self.vae.config, "shift_factor")
+                and getattr(self.vae.config, "shift_factor", None) is not None
+            ):
                 # Handle both scaling and shifting
                 latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
             else:
                 # Just scale using standard approach
                 latents = latents / self.vae.config.scaling_factor
-                
+
             image = self.vae.decode(latents, return_dict=False)[0]
             image = self.image_processor.postprocess(image, output_type=output_type)

@@ -700,4 +754,4 @@ def __call__(
         if not return_dict:
             return (image,)

-        return ImagePipelineOutput(images=image)  
+        return ImagePipelineOutput(images=image)