@@ -549,8 +549,11 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        import numpy as np
+        sigma = timestep[0]
+        latents = sigma * noise + (1.0 - sigma) * image_latents
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        np.save("reference_image_latent.npy", latents.detach().cpu().float().numpy())
         return latents, latent_image_ids
 
     @property
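The replaced `scale_noise` call is written out inline here as the linear flow-matching interpolation x_sigma = sigma * noise + (1 - sigma) * x_0, evaluated at the first sigma of the schedule. A minimal standalone sketch of that blend; the helper name and latent shape are illustrative assumptions, not taken from the diff:

```python
import torch

def interpolate_latents(image_latents: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # sigma = 0 returns the clean image latents; sigma = 1 returns pure noise.
    return sigma * noise + (1.0 - sigma) * image_latents

image_latents = torch.randn(1, 16, 64, 64)  # hypothetical unpacked latent shape
noise = torch.randn_like(image_latents)
x_sigma = interpolate_latents(image_latents, noise, sigma=0.25)
```

With the shifted schedule introduced further down, `timestep[0]` comes out as 0, so `prepare_latents` starts from the clean image latents; the `np.save` call then dumps them for offline inspection.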
@@ -569,6 +572,35 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
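These four toggles simply forward to the slicing/tiling switches on the underlying VAE. A hedged usage sketch; the pipeline class and checkpoint name are assumptions for illustration, not taken from this diff:

```python
import torch
from diffusers import FluxImg2ImgPipeline  # assumed pipeline class

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # illustrative checkpoint
).to("cuda")

pipe.enable_vae_slicing()  # decode batched latents one slice at a time
pipe.enable_vae_tiling()   # decode/encode large images tile by tile
# ... run inference ...
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```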
@@ -582,7 +614,8 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
-        controller_guidance: float = 5.0,
+        controller_guidance: float = 0.5,
+        stop_step: int = 0,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
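`controller_guidance` now defaults to 0.5, and the new `stop_step` lets the integration loop exit early (the loop below runs `num_inference_steps - stop_step` iterations). A sketch of a call using the new keywords; apart from the two new arguments, the names follow the usual img2img signature and are assumptions here:

```python
result = pipe(
    prompt="a photo of a forest at dawn",  # illustrative
    image=init_image,                      # assumed PIL image or tensor
    num_inference_steps=28,
    guidance_scale=7.0,
    controller_guidance=0.5,  # weight between unconditional and conditional fields
    stop_step=4,              # skip the last 4 Euler steps
)
```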
@@ -728,26 +761,23 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        import math
+        def time_shift(mu: float, sigma: float, t: torch.Tensor):
+            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
 
-        # 4.Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        def get_lin_function(
+            x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+        ) -> Callable[[float], float]:
+            m = (y2 - y1) / (x2 - x1)
+            b = y1 - m * x1
+            return lambda x: m * x + b
+
+        # 4. Prepare timesteps
+        mu = get_lin_function()(image_seq_len)
+        timesteps = torch.linspace(0, 1, num_inference_steps + 1)
+        timesteps = time_shift(mu, 1.0, timesteps).to("cuda", torch.bfloat16)
 
         if num_inference_steps < 1:
             raise ValueError(
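The hand-rolled schedule replaces `calculate_shift`/`retrieve_timesteps`: `mu` is a linear function of the image token count, anchored at (256, 0.5) and (4096, 1.15), and `time_shift` warps a uniform grid on [0, 1] accordingly (note the result is moved to a hardcoded `"cuda"` device in `bfloat16`). As far as this hunk shows, `strength` no longer affects the schedule now that `get_timesteps` is gone. A standalone check with illustrative numbers; a 1024x1024 image with an assumed VAE scale factor of 16 gives `image_seq_len = 4096`, the upper anchor:

```python
import math
import torch

def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15):
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * x + (y1 - m * x1)

image_seq_len = (1024 // 16) * (1024 // 16)  # 4096, assuming vae_scale_factor = 16
mu = get_lin_function()(image_seq_len)       # exactly 1.15 at the upper anchor
timesteps = time_shift(mu, 1.0, torch.linspace(0, 1, 5))
print(timesteps)  # approx. tensor([0.0000, 0.5128, 0.7595, 0.9045, 1.0000])
```

Larger `mu` pushes more of the grid toward 1, so high-resolution images spend more steps at high noise levels; at t = 0 the `1 / t` term overflows to infinity and the shifted value cleanly evaluates to 0.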
@@ -758,10 +788,9 @@ def __call__(
 
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-
         latents, latent_image_ids = self.prepare_latents(
             init_image,
-            latent_timestep,
+            timesteps,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -784,18 +813,17 @@ def __call__(
 
         # fix noise sample y1
         y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(num_inference_steps - stop_step):
                 if self.interrupt:
                     continue
-
+                t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    timestep=timestep / 1000,
+                    timestep=timestep,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
@@ -806,30 +834,14 @@ def __call__(
                 )[0]
 
                 unconditional_vector_field = noise_pred
-                conditional_vector_field = (y1 - unconditional_vector_field) / (1 - (timestep / 1000))
+                conditional_vector_field = (y1 - latents) / (1 - timestep)
                 controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
-                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
-
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
-                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                # Get the corresponding sigma values and take an Euler step
+                sigma = timesteps[i]
+                sigma_next = timesteps[i + 1]
+                latents = latents + (sigma_next - sigma) * controlled_vector_field
+                progress_bar.update()
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
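The scheduler step and per-step callbacks are replaced by a plain Euler update along a guided field: v_cond = (y1 - x) / (1 - sigma) transports the current state straight toward the fixed noise sample y1, and `controller_guidance` interpolates between the transformer's unconditional prediction and that field. Because the loop stops before the final grid point, sigma stays below 1 and the division never hits zero. A self-contained sketch of one step; the function name and shapes are illustrative, not the pipeline's API:

```python
import torch

def controlled_euler_step(latents, model_pred, y1, sigma, sigma_next, controller_guidance=0.5):
    # Unconditional field: the model's velocity prediction.
    unconditional = model_pred
    # Conditional field: points from the current state toward the fixed noise y1.
    conditional = (y1 - latents) / (1.0 - sigma)
    # Blend the two fields, then integrate over the sigma interval.
    controlled = unconditional + controller_guidance * (conditional - unconditional)
    return latents + (sigma_next - sigma) * controlled

x = torch.randn(1, 4096, 64)  # hypothetical packed latent sequence
y1 = torch.randn_like(x)      # fixed noise target, sampled once before the loop
v = torch.randn_like(x)       # stand-in for the transformer output
x_next = controlled_euler_step(x, v, y1, sigma=0.3, sigma_next=0.4)
```

With sigma increasing from 0 toward 1 under this schedule, each step moves the clean image latents further toward y1, i.e. the loop integrates the controlled ODE from image to noise rather than denoising.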