 from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PixArtImageProcessor, PipelineImageInput
+from ...image_processor import PipelineImageInput, PixArtImageProcessor
 from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
 from .pipeline_output import SanaPipelineOutput

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-        scheduler,
-        num_inference_steps: Optional[int] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        timesteps: Optional[List[int]] = None,
-        sigmas: Optional[List[float]] = None,
-        **kwargs,
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
 ):
     r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
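
# --- Editor sketch, not part of the diff: how the helper above is typically used. ---
# Hedged usage notes for `retrieve_timesteps`, assuming an already-constructed
# diffusers scheduler; `timesteps` and `sigmas` are mutually exclusive overrides
# of the default `num_inference_steps` schedule.
#
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=2, device="cuda")
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, sigmas=[1.0, 0.5, 0.0], device="cuda")
#
# Passing both `timesteps` and `sigmas` raises a ValueError inside the copied helper.
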
@@ -149,12 +150,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
-            self,
-            tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
-            text_encoder: Gemma2PreTrainedModel,
-            vae: AutoencoderDC,
-            transformer: SanaTransformer2DModel,
-            scheduler: DPMSolverMultistepScheduler,
+        self,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
+        vae: AutoencoderDC,
+        transformer: SanaTransformer2DModel,
+        scheduler: DPMSolverMultistepScheduler,
     ):
         super().__init__()

@@ -200,13 +201,13 @@ def disable_vae_tiling(self):

     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
     def _get_gemma_prompt_embeds(
-            self,
-            prompt: Union[str, List[str]],
-            device: torch.device,
-            dtype: torch.dtype,
-            clean_caption: bool = False,
-            max_sequence_length: int = 300,
-            complex_human_instruction: Optional[List[str]] = None,
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -258,16 +259,16 @@ def _get_gemma_prompt_embeds(
         return prompt_embeds, prompt_attention_mask
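
# --- Editor sketch, not part of the diff: rough shape of _get_gemma_prompt_embeds. ---
# A hedged outline inferred from the signature and from the Sana pipeline it is
# copied from; preprocessing and special-token handling are simplified here.
#
#     if complex_human_instruction:
#         chi = "\n".join(complex_human_instruction)   # chat-style instruction prefix
#         prompt = [chi + p for p in prompt]           # prepended to every prompt
#     inputs = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
#     prompt_embeds = text_encoder(inputs.input_ids.to(device),
#                                  attention_mask=inputs.attention_mask.to(device))[0]
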

     def encode_prompt(
-            self,
-            prompt: Union[str, List[str]],
-            num_images_per_prompt: int = 1,
-            device: Optional[torch.device] = None,
-            prompt_embeds: Optional[torch.Tensor] = None,
-            prompt_attention_mask: Optional[torch.Tensor] = None,
-            clean_caption: bool = False,
-            max_sequence_length: int = 300,
-            complex_human_instruction: Optional[List[str]] = None,
-            lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -366,25 +367,25 @@ def get_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(num_inference_steps * strength, num_inference_steps)

         t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
         if hasattr(self.scheduler, "set_begin_index"):
             self.scheduler.set_begin_index(t_start * self.scheduler.order)

         return timesteps, num_inference_steps - t_start

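# --- Editor sketch, not part of the diff: worked example of get_timesteps. ---
# Pure-Python illustration of the strength -> timestep arithmetic above.
num_inference_steps, strength, order = 20, 0.6, 1
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 12.0
t_start = int(max(num_inference_steps - init_timestep, 0))  # 8
steps_run = num_inference_steps - t_start  # 12
# With strength=0.6 the first 8 of 20 timesteps are skipped and the encoded input
# image is denoised for the remaining 12; strength=1.0 would run the full schedule.
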
     def check_inputs(
-            self,
-            prompt,
-            strength,
-            height,
-            width,
-            num_inference_steps,
-            timesteps,
-            max_timesteps,
-            intermediate_timesteps,
-            callback_on_step_end_tensor_inputs=None,
-            prompt_embeds=None,
-            prompt_attention_mask=None,
+        self,
+        prompt,
+        strength,
+        height,
+        width,
+        num_inference_steps,
+        timesteps,
+        max_timesteps,
+        intermediate_timesteps,
+        callback_on_step_end_tensor_inputs=None,
+        prompt_embeds=None,
+        prompt_attention_mask=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
@@ -393,7 +394,7 @@ def check_inputs(
             raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
-                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
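
# --- Editor note, not part of the diff. ---
# `_callback_tensor_inputs` is ["latents", "prompt_embeds"] (see the class attribute
# earlier in this diff), so for example:
#     callback_on_step_end_tensor_inputs=["latents"]                   # accepted
#     callback_on_step_end_tensor_inputs=["latents", "prompt_embeds"]  # accepted
#     callback_on_step_end_tensor_inputs=["noise_pred"]                # raises ValueError here
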
@@ -571,12 +572,12 @@ def _clean_caption(self, caption):

     # Copied from diffusers.pipelines.sana.pipeline_sana_controlnet.SanaControlNetPipeline.prepare_image
     def prepare_image(
-            self,
-            image,
-            width,
-            height,
-            device,
-            dtype,
+        self,
+        image,
+        width,
+        height,
+        device,
+        dtype,
     ):
         if isinstance(image, torch.Tensor):
             pass
@@ -588,17 +589,9 @@ def prepare_image(
         return image

     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
-    def prepare_latents(self,
-                        image,
-                        timestep,
-                        batch_size,
-                        num_channels_latents,
-                        height,
-                        width,
-                        dtype,
-                        device,
-                        generator,
-                        latents=None):
+    def prepare_latents(
+        self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None
+    ):
         if latents is not None:
             return latents.to(device=device, dtype=dtype)

@@ -609,7 +602,6 @@ def prepare_latents(self,
             int(width) // self.vae_scale_factor,
         )

-
         if image.shape[1] != num_channels_latents:
             image = self.vae.encode(image).latent
             image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
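
# --- Editor note, not part of the diff. ---
# The encode path scales the VAE latents by both the VAE `scaling_factor` and the
# scheduler's `sigma_data` (in consistency-model parameterizations, the assumed
# standard deviation of the data). Because that scaling already happens here, the
# commented-out `latents * sigma_data` rescale later in __call__ is indeed
# redundant, as the in-code note there observes.
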
@@ -657,41 +649,41 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-            self,
-            prompt: Union[str, List[str]] = None,
-            num_inference_steps: int = 2,
-            timesteps: List[int] = None,
-            max_timesteps: float = 1.57080,
-            intermediate_timesteps: float = 1.3,
-            guidance_scale: float = 4.5,
-            image: PipelineImageInput = None,
-            strength: float = 0.6,
-            num_images_per_prompt: Optional[int] = 1,
-            height: int = 1024,
-            width: int = 1024,
-            eta: float = 0.0,
-            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-            latents: Optional[torch.Tensor] = None,
-            prompt_embeds: Optional[torch.Tensor] = None,
-            prompt_attention_mask: Optional[torch.Tensor] = None,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            clean_caption: bool = False,
-            use_resolution_binning: bool = True,
-            attention_kwargs: Optional[Dict[str, Any]] = None,
-            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-            max_sequence_length: int = 300,
-            complex_human_instruction: List[str] = [
-                "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
-                "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
-                "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
-                "Here are examples of how to transform or refine prompts:",
-                "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
-                "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
-                "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
-                "User Prompt: ",
-            ],
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_inference_steps: int = 2,
+        timesteps: List[int] = None,
+        max_timesteps: float = 1.57080,
+        intermediate_timesteps: float = 1.3,
+        guidance_scale: float = 4.5,
+        image: PipelineImageInput = None,
+        strength: float = 0.6,
+        num_images_per_prompt: Optional[int] = 1,
+        height: int = 1024,
+        width: int = 1024,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        clean_caption: bool = False,
+        use_resolution_binning: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 300,
+        complex_human_instruction: List[str] = [
+            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+            "Here are examples of how to transform or refine prompts:",
+            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+            "User Prompt: ",
+        ],
     ) -> Union[SanaPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
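
# --- Editor sketch, not part of the diff: end-to-end usage of this pipeline. ---
# A hedged sketch only; the checkpoint id and image URL below are hypothetical
# placeholders, not values confirmed by this PR.
import torch
from diffusers import SanaSprintImg2ImgPipeline
from diffusers.utils import load_image

pipe = SanaSprintImg2ImgPipeline.from_pretrained(
    "org/sana-sprint-checkpoint",  # hypothetical checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")
init_image = load_image("https://example.com/cat.png")  # placeholder input image
out = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    image=init_image,
    strength=0.6,           # fraction of the schedule applied to the input image
    num_inference_steps=2,  # SANA-Sprint is a few-step distilled model
)
image = out.images[0]
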
@@ -874,7 +866,7 @@ def __call__(
         )

         # I think this is redundant given the scaling in prepare_latents
-        #latents = latents * self.scheduler.config.sigma_data
+        # latents = latents * self.scheduler.config.sigma_data

         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
         guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
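
# --- Editor note, not part of the diff. ---
# `guidance_scale` is turned into a per-sample conditioning tensor and handed to the
# transformer directly, rather than applied through a second unconditional forward
# pass; this is consistent with a guidance-distilled model, which would explain the
# single forward pass per step in the denoising loop below.
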
@@ -902,7 +894,7 @@ def __call__(

                 scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
                 latent_model_input = latents_model_input * torch.sqrt(
-                    scm_timestep_expanded ** 2 + (1 - scm_timestep_expanded) ** 2
+                    scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                 )

                 # predict noise model_output
@@ -917,9 +909,9 @@ def __call__(
                 )[0]

                 noise_pred = (
-                    (1 - 2 * scm_timestep_expanded) * latent_model_input
-                    + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded ** 2) * noise_pred
-                ) / torch.sqrt(scm_timestep_expanded ** 2 + (1 - scm_timestep_expanded) ** 2)
+                    (1 - 2 * scm_timestep_expanded) * latent_model_input
+                    + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
+                ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
                 noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
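
# --- Editor sketch, not part of the diff: the two transforms above, spelled out. ---
# Writing s = scm_timestep_expanded and F for the transformer, the code implements
#     x_in = x_t * sqrt(s^2 + (1 - s)^2)
#     pred = ((1 - 2s) * x_in + (1 - 2s + 2s^2) * F(x_in)) / sqrt(s^2 + (1 - s)^2)
#     pred = pred * sigma_data
# i.e. the model output is mapped back to the latent scale the scheduler step expects.
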

                 # compute previous image: x_t -> x_t-1