from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusion3PipelineOutput

-
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

-
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
@@ -77,7 +75,7 @@

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
@@ -91,12 +89,12 @@ def retrieve_latents(

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
-    **kwargs,
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
@@ -188,16 +186,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
-        self,
-        transformer: SD3Transformer2DModel,
-        scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModelWithProjection,
-        tokenizer: CLIPTokenizer,
-        text_encoder_2: CLIPTextModelWithProjection,
-        tokenizer_2: CLIPTokenizer,
-        text_encoder_3: T5EncoderModel,
-        tokenizer_3: T5TokenizerFast,
+        self,
+        transformer: SD3Transformer2DModel,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer_2: CLIPTokenizer,
+        text_encoder_3: T5EncoderModel,
+        tokenizer_3: T5TokenizerFast,
    ):
        super().__init__()

@@ -218,15 +216,18 @@ def __init__(
        )
        self.tokenizer_max_length = self.tokenizer.model_max_length
        self.default_sample_size = self.transformer.config.sample_size
+        self.patch_size = (
+            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+        )

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 256,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
@@ -257,7 +258,7 @@ def _get_t5_prompt_embeds(
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
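
Note on the slice above: `tokenizer_max_length - 1 : -1` decodes exactly the tokens dropped by truncation, skipping the trailing special token. A toy sketch with made-up ids (values are illustrative, not from the pipeline):

```python
# Toy illustration of the `tokenizer_max_length - 1 : -1` slice (ids are made up).
tokenizer_max_length = 5
untruncated_ids = [101, 7592, 2088, 2026, 3899, 2003, 102]  # 7 ids, 2 too many
# Everything past the last kept content token, minus the trailing special token:
removed = untruncated_ids[tokenizer_max_length - 1 : -1]
print(removed)  # [3899, 2003]
```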
@@ -278,12 +279,12 @@ def _get_t5_prompt_embeds(

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        clip_skip: Optional[int] = None,
-        clip_model_index: int = 0,
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        clip_skip: Optional[int] = None,
+        clip_model_index: int = 0,
    ):
        device = device or self._execution_device

@@ -307,7 +308,7 @@ def _get_clip_prompt_embeds(
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
@@ -334,23 +335,23 @@ def _get_clip_prompt_embeds(

    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        clip_skip: Optional[int] = None,
-        max_sequence_length: int = 256,
-        lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        prompt_3: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
+        lora_scale: Optional[float] = None,
    ):
        r"""

@@ -527,26 +528,37 @@ def encode_prompt(
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        prompt_3,
-        strength,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        negative_prompt_3=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        height,
+        width,
+        strength,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        negative_prompt_3=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
    ):
+        if (
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+                f" You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+            )
+
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
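
For intuition on the new divisibility check: with SD3's usual VAE scale factor of 8 and transformer patch size of 2 (assumed typical values, not part of this diff), the divisor is 16. A quick sketch of the suggested-size arithmetic from the error message:

```python
# Assumed typical SD3 values: vae_scale_factor = 8, patch_size = 2.
vae_scale_factor, patch_size = 8, 2
divisor = vae_scale_factor * patch_size  # 16

height = 1000
print(height % divisor)           # 8  -> 1000 fails the new check
print(height - height % divisor)  # 992 -> the size the error message suggests
```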
@@ -620,7 +632,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

@@ -647,7 +659,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt

        elif isinstance(generator, list):
            init_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
@@ -706,32 +718,34 @@ def interrupt(self):
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        prompt_3: Optional[Union[str, List[str]]] = None,
-        image: PipelineImageInput = None,
-        strength: float = 0.6,
-        num_inference_steps: int = 50,
-        sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 7.0,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 256,
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        image: PipelineImageInput = None,
+        strength: float = 0.6,
+        num_inference_steps: int = 50,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 7.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -824,12 +838,16 @@ def __call__(
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
+            height,
+            width,
            strength,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
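
With the defaults added above and typical SD3 config values (`sample_size` of 128 and `vae_scale_factor` of 8 — assumptions here, not part of this diff), leaving `height`/`width` unset resolves them to 1024, which passes the new divisibility check:

```python
# Assumed typical SD3 config values; the diff computes the defaults the same way.
default_sample_size = 128  # transformer.config.sample_size
vae_scale_factor = 8

height = None
height = height or default_sample_size * vae_scale_factor
print(height)  # 1024, divisible by 16, so check_inputs passes
```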
@@ -890,7 +908,7 @@ def __call__(
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 3. Preprocess image
-        image = self.image_processor.preprocess(image)
+        image = self.image_processor.preprocess(image, height, width)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
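
Finally, a minimal end-to-end sketch of calling the patched pipeline with explicit `height` and `width`; the checkpoint ID, image URL, and sizes are illustrative assumptions rather than part of this change:

```python
import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

# Any SD3 img2img-capable checkpoint works; this ID is just an example.
pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)

# 1024 is divisible by vae_scale_factor * patch_size (16), so the new
# check_inputs validation passes; e.g. 1000 would raise ValueError instead.
image = pipe(
    prompt="a fantasy landscape, detailed, high quality",
    image=init_image,
    height=1024,
    width=1024,
    strength=0.6,
).images[0]
image.save("sd3_img2img.png")
```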