from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusion3PipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
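For orientation, a minimal sketch of how this helper is typically driven, using the function above together with a VAE; the checkpoint id and the random input tensor are illustrative assumptions, not taken from this file:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
generator = torch.Generator().manual_seed(0)

# "sample" draws stochastically from the VAE posterior; "argmax" takes its mode.
stochastic_latents = retrieve_latents(vae.encode(image), generator=generator)
deterministic_latents = retrieve_latents(vae.encode(image), sample_mode="argmax")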

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
    ...
    """
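A hedged usage sketch, assuming a default-configured FlowMatchEulerDiscreteScheduler; the step count and sigma values below are arbitrary. The helper returns the timestep schedule together with the effective number of inference steps:

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Default path: the scheduler builds its own schedule for the requested step count.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28)

# Custom-sigmas path: the step count is derived from the schedule's length.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25])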

class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
    # ...
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
    ):
        super().__init__()
        # ...
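As a sketch of what satisfying this constructor looks like, assuming the subfolder layout of the public SD3 checkpoint (the repo id is an assumption; calling `from_pretrained` on the pipeline class itself is the usual one-line shortcut):

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint
pipe = StableDiffusion3Img2ImgPipeline(
    transformer=SD3Transformer2DModel.from_pretrained(repo, subfolder="transformer"),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler"),
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae"),
    text_encoder=CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
    text_encoder_2=CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2"),
    tokenizer_2=CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
    text_encoder_3=T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3"),
    tokenizer_3=T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_3"),
)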

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
        # ...
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )
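Since truncation is silent apart from this warning, a caller can check prompt length up front; a hedged sketch, with the tokenizer checkpoint an assumption:

from transformers import T5TokenizerFast

tokenizer_3 = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")  # assumed source
max_sequence_length = 256

prompt = "an extremely detailed prompt " * 64
n_tokens = len(tokenizer_3(prompt).input_ids)
if n_tokens > max_sequence_length:
    print(f"{n_tokens} tokens; everything past {max_sequence_length} will be dropped")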
        # ...

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        device = device or self._execution_device
        # ...
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        # ...

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""
        ...
        """
        # ...
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
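For orientation, the three encoders are combined roughly as follows: the CLIP-L and CLIP-G token embeddings are concatenated along the channel axis, zero-padded up to the T5 width, then concatenated with the T5 embeddings along the sequence axis. This is a schematic with dummy tensors and typical SD3 hidden sizes, not the method body itself:

import torch
import torch.nn.functional as F

prompt_embed = torch.randn(1, 77, 768)       # CLIP-L hidden states
prompt_2_embed = torch.randn(1, 77, 1280)    # CLIP-G hidden states
t5_prompt_embed = torch.randn(1, 256, 4096)  # T5-XXL hidden states

clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)  # (1, 77, 2048)
clip_prompt_embeds = F.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]))
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)  # (1, 333, 4096)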

    def check_inputs(
        self,
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        strength,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if (
            height % (self.vae_scale_factor * self.patch_size) != 0
            or width % (self.vae_scale_factor * self.patch_size) != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
            )
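Concretely, with the usual SD3 configuration (vae_scale_factor = 8, patch_size = 2; both values are assumptions here, read from the loaded models at runtime), height and width must be multiples of 16:

vae_scale_factor = 8  # assumed: 2 ** (len(vae.config.block_out_channels) - 1)
patch_size = 2        # assumed SD3 transformer patch size

for side in (512, 1000, 1024):
    ok = side % (vae_scale_factor * patch_size) == 0
    print(side, "ok" if ok else "rejected")  # 512 ok, 1000 rejected, 1024 ok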
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # ...

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start
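A worked example of the arithmetic above: with num_inference_steps=50 and strength=0.6, the first 20 scheduler steps are skipped and 30 denoising steps actually run, so higher strength means more added noise and more denoising work:

num_inference_steps, strength = 50, 0.6
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
print(num_inference_steps - t_start)                                      # 30 steps run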

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        # ...
        elif isinstance(generator, list):
            init_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
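This branch is what enables per-image reproducibility: pass one generator per batch element so each image's VAE sampling is independently seeded, e.g.:

import torch

generators = [torch.Generator("cpu").manual_seed(seed) for seed in (0, 1, 2, 3)]
# e.g. pipe(prompt, image=init_image, num_images_per_prompt=4, generator=generators)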
        # ...

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        image: PipelineImageInput = None,
        strength: float = 0.6,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        ...
        """
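A hedged end-to-end sketch of calling the pipeline; the repo id and image URL are placeholders:

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/input.png")  # placeholder URL
image = pipe(
    "a photo of a cat wearing a wizard hat",
    image=init_image,
    strength=0.6,            # how far toward noise the input is pushed
    num_inference_steps=50,
    guidance_scale=7.0,
).images[0]
image.save("out.png")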