         >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
         >>> init_image = init_image.resize((768, 512))

-        >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
+        >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16)
         >>> pipe = pipe.to("cuda")
         >>> prompt = "A fantasy landscape, trending on artstation"
         >>> image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0]
@@ -338,19 +338,20 @@ def prepare_latents(
         return latents

     def get_timesteps(self, num_inference_steps, strength, device):
-        # 1. Call set_timesteps with num_inference_steps
+        # Set timesteps using the full range initially
         self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps.to(device=device)

-        # 2. Calculate strength-based number of steps and offset
-        init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps)
-        t_start = max(num_inference_steps - init_timestep_count, 0)
+        if len(timesteps) != num_inference_steps:
+            num_inference_steps = len(timesteps)  # Adjust if scheduler changed num_steps

-        # 3. Get the timesteps *after* set_timesteps has been called (now has length num_inference_steps)
+        # Get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
         timesteps = self.scheduler.timesteps[t_start:]

-        # 4. Return the correct slice and the number of actual steps
-        num_actual_inference_steps = len(timesteps)
-        return timesteps, num_actual_inference_steps
+        return timesteps, num_inference_steps - t_start

     def prepare_img2img_latents(
         self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
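For reference, a minimal sketch of the strength-based slicing that the rewritten `get_timesteps` performs, using the docstring's `strength=0.75` and `num_inference_steps=50` and a plain linearly spaced tensor in place of the real scheduler (illustrative only):

```python
import torch

# Stand-in for scheduler.timesteps after set_timesteps(50): 50 descending values.
num_inference_steps, strength = 50, 0.75
timesteps = torch.linspace(1000.0, 1.0, num_inference_steps)

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 37.5
t_start = int(max(num_inference_steps - init_timestep, 0))                # 12
timesteps = timesteps[t_start:]                                           # drop the 12 noisiest steps

print(len(timesteps), num_inference_steps - t_start)  # 38 38 -> only 38 denoising steps actually run
```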
@@ -385,11 +386,20 @@ def prepare_img2img_latents(
                         f" Batch size must be divisible by the image batch size."
                     )

+            # Temporarily move VAE to float32 for encoding
+            vae_dtype = self.vae.dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=torch.float32)
+
             # encode the init image into latents and scale the latents
             # 1. Get VAE distribution parameters (on device)
-            latent_dist = self.vae.encode(image).latent_dist
+            latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist
             mean, std = latent_dist.mean, latent_dist.std  # Already on device

+            # Restore VAE dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=vae_dtype)
+
             # 2. Sample noise for each batch element individually if using multiple generators
             if isinstance(generator, list):
                 sample = torch.cat(
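The added lines follow a save/upcast/restore pattern around the VAE encode. A condensed sketch of that pattern as a standalone helper, assuming a diffusers `AutoencoderKL`-style `vae` whose `encode(...).latent_dist` exposes `mean` and `std` (the helper itself is illustrative, not pipeline code):

```python
import torch

def encode_in_fp32(vae, image):
    # Remember the VAE's dtype, encode in float32 (fp16 VAEs can overflow here),
    # then put the VAE back the way it was.
    original_dtype = vae.dtype
    if original_dtype != torch.float32:
        vae.to(dtype=torch.float32)

    latent_dist = vae.encode(image.to(dtype=torch.float32)).latent_dist
    # Reparameterized sample from the posterior, as in the pipeline: mean + std * eps
    latents = latent_dist.mean + latent_dist.std * torch.randn_like(latent_dist.mean)

    if original_dtype != torch.float32:
        vae.to(dtype=original_dtype)
    return latents.to(dtype=original_dtype)
```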
@@ -416,7 +426,7 @@ def prepare_img2img_latents(
             latents = latents * self.vae.config.scaling_factor

             # get the original timestep using init_timestep
-            init_timestep = timestep
+            init_timestep = timestep  # Use the passed timestep directly

             # add noise to latents using the timesteps
             # Handle noise generation with multiple generators if provided
@@ -439,20 +449,7 @@ def prepare_img2img_latents(
                     latents.shape, generator=generator, device=generator_device, dtype=latents.dtype
                 ).to(latents.device)

-            # Ensure timestep tensor is on the same device
-            t = init_timestep.to(latents.device)
-
-            # Normalize timestep to [0, 1] range (using scheduler's config)
-            t = t / self.scheduler.config.num_train_timesteps
-
-            # Reshape t to match the dimensions needed for broadcasting
-            required_dims = len(latents.shape)
-            current_dims = len(t.shape)
-            for _ in range(required_dims - current_dims):
-                t = t.unsqueeze(-1)
-
-            # Interpolation: x_t = t * x_1 + (1 - t) * x_0
-            latents = t * noise + (1 - t) * latents
+            latents = self.scheduler.scale_noise(latents, init_timestep, noise)

         return latents

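The removed manual interpolation is what a flow-matching scheduler's `scale_noise` computes, up to the exact sigma schedule. A sketch of the equivalent computation, assuming a simple linear `sigma = t / num_train_timesteps` (the real scheduler may use a shifted schedule, so this is illustrative only):

```python
import torch

def scale_noise_sketch(latents, timestep, noise, num_train_timesteps=1000):
    # Flow-matching forward process: x_t = sigma * noise + (1 - sigma) * x_0.
    # Illustrative only; scheduler.scale_noise uses the scheduler's own sigmas.
    sigma = (timestep.to(latents.device).float() / num_train_timesteps).to(latents.dtype)
    while sigma.ndim < latents.ndim:
        sigma = sigma.unsqueeze(-1)  # make sigma broadcast over (B, C, H, W)
    return sigma * noise + (1 - sigma) * latents
```

Delegating to `self.scheduler.scale_noise(latents, init_timestep, noise)` keeps the noising consistent with whatever sigma schedule the scheduler is configured with, which the hand-rolled version above cannot guarantee.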
@@ -657,8 +654,10 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

         # 5. Prepare timesteps
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
-        latent_timestep = timesteps[:1]
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device
+        )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)  # Get the first timestep(s) for initial noise

         # 6. Prepare latent variables
         latents = self.prepare_img2img_latents(
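`latent_timestep` is simply the first retained timestep repeated once per latent in the batch, so every generated image is noised to the same strength-dependent level; for example (values hypothetical):

```python
import torch

timesteps = torch.tensor([760.0, 740.0, 720.0])  # as if returned by get_timesteps
batch_size, num_images_per_prompt = 2, 2
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
print(latent_timestep)  # tensor([760., 760., 760., 760.])
```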
@@ -727,11 +726,11 @@ def __call__(
         if output_type == "latent":
             image = latents
         else:
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
-            if needs_upcasting:
-                self.upcast_vae()
-                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+            # Always upcast VAE to float32 for decoding
+            vae_dtype = self.vae.dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=torch.float32)
+                latents = latents.to(dtype=torch.float32)

             # Apply proper scaling factor and shift factor if available
             if (
@@ -746,6 +745,11 @@ def __call__(
                 latents = latents / self.vae.config.scaling_factor

             image = self.vae.decode(latents, return_dict=False)[0]
+
+            # Restore VAE dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=vae_dtype)
+
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
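Taken together with the unconditional upcast above, the decode path mirrors the encode path: upcast to float32, undo the latent scaling (and shift, when the VAE config defines one), decode, then restore the dtype. A condensed sketch under those assumptions (`decode_in_fp32` is an illustrative helper, not pipeline code):

```python
import torch

def decode_in_fp32(vae, latents):
    original_dtype = vae.dtype
    if original_dtype != torch.float32:
        vae.to(dtype=torch.float32)
        latents = latents.to(dtype=torch.float32)

    # Invert the scaling applied at encode time; apply the shift factor if the config has one.
    shift_factor = getattr(vae.config, "shift_factor", None)
    if shift_factor is not None:
        latents = latents / vae.config.scaling_factor + shift_factor
    else:
        latents = latents / vae.config.scaling_factor

    image = vae.decode(latents, return_dict=False)[0]

    if original_dtype != torch.float32:
        vae.to(dtype=original_dtype)  # put the VAE back in its original dtype
    return image
```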