
Commit 1ab5d8c

update StableDiffusion3Img2ImgPipeline: add image size validation
1 parent 6131a93 commit 1ab5d8c

File tree

1 file changed (+112 −94)

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 112 additions & 94 deletions
@@ -41,15 +41,13 @@
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusion3PipelineOutput
 
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
 
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
@@ -77,7 +75,7 @@
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+        encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
@@ -91,12 +89,12 @@ def retrieve_latents(
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
-    **kwargs,
+        scheduler,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+        **kwargs,
 ):
     r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
@@ -188,16 +186,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
-        self,
-        transformer: SD3Transformer2DModel,
-        scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModelWithProjection,
-        tokenizer: CLIPTokenizer,
-        text_encoder_2: CLIPTextModelWithProjection,
-        tokenizer_2: CLIPTokenizer,
-        text_encoder_3: T5EncoderModel,
-        tokenizer_3: T5TokenizerFast,
+            self,
+            transformer: SD3Transformer2DModel,
+            scheduler: FlowMatchEulerDiscreteScheduler,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModelWithProjection,
+            tokenizer: CLIPTokenizer,
+            text_encoder_2: CLIPTextModelWithProjection,
+            tokenizer_2: CLIPTokenizer,
+            text_encoder_3: T5EncoderModel,
+            tokenizer_3: T5TokenizerFast,
     ):
         super().__init__()
 
@@ -218,15 +216,18 @@ def __init__(
         )
         self.tokenizer_max_length = self.tokenizer.model_max_length
         self.default_sample_size = self.transformer.config.sample_size
+        self.patch_size = (
+            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+        )
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 256,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+            self,
+            prompt: Union[str, List[str]] = None,
+            num_images_per_prompt: int = 1,
+            max_sequence_length: int = 256,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
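
The new `self.patch_size` attribute mirrors the text-to-image pipeline: together with `vae_scale_factor` it fixes the coarsest spatial unit the model can handle, so image sizes must be multiples of `vae_scale_factor * patch_size`. A minimal sketch of that arithmetic, assuming the stock SD3 configuration (a VAE with four `block_out_channels`, giving `vae_scale_factor = 8`, and `patch_size = 2`):

# Stand-alone illustration, not part of the diff; config values are assumed SD3 defaults.
block_out_channels = (128, 256, 512, 512)                # assumed SD3 VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)    # -> 8
patch_size = 2                                           # SD3Transformer2DModel default
divisor = vae_scale_factor * patch_size                  # -> 16
print(divisor)  # height and width must be multiples of 16
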
@@ -257,7 +258,7 @@ def _get_t5_prompt_embeds(
         untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
@@ -278,12 +279,12 @@ def _get_t5_prompt_embeds(
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        clip_skip: Optional[int] = None,
-        clip_model_index: int = 0,
+            self,
+            prompt: Union[str, List[str]],
+            num_images_per_prompt: int = 1,
+            device: Optional[torch.device] = None,
+            clip_skip: Optional[int] = None,
+            clip_model_index: int = 0,
     ):
         device = device or self._execution_device
 
@@ -307,7 +308,7 @@ def _get_clip_prompt_embeds(
         text_input_ids = text_inputs.input_ids
         untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer_max_length} tokens: {removed_text}"
@@ -334,23 +335,23 @@ def _get_clip_prompt_embeds(
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
     def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        clip_skip: Optional[int] = None,
-        max_sequence_length: int = 256,
-        lora_scale: Optional[float] = None,
+            self,
+            prompt: Union[str, List[str]],
+            prompt_2: Union[str, List[str]],
+            prompt_3: Union[str, List[str]],
+            device: Optional[torch.device] = None,
+            num_images_per_prompt: int = 1,
+            do_classifier_free_guidance: bool = True,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            negative_prompt_3: Optional[Union[str, List[str]]] = None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            clip_skip: Optional[int] = None,
+            max_sequence_length: int = 256,
+            lora_scale: Optional[float] = None,
     ):
         r"""
 
@@ -527,26 +528,37 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
     def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        prompt_3,
-        strength,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        negative_prompt_3=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
+            self,
+            prompt,
+            prompt_2,
+            prompt_3,
+            height,
+            width,
+            strength,
+            negative_prompt=None,
+            negative_prompt_2=None,
+            negative_prompt_3=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            pooled_prompt_embeds=None,
+            negative_pooled_prompt_embeds=None,
+            callback_on_step_end_tensor_inputs=None,
+            max_sequence_length=None,
     ):
+        if (
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+                f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+            )
+
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -620,7 +632,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(num_inference_steps * strength, num_inference_steps)
 
         t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
         if hasattr(self.scheduler, "set_begin_index"):
             self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
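
This hunk only drops the space before the slice colon; the surrounding `get_timesteps` logic is unchanged. As a rough illustration of that logic (assuming a first-order scheduler, i.e. `scheduler.order == 1`), `strength` decides how many of the requested steps actually run:

num_inference_steps = 50
strength = 0.6
order = 1  # assumed: FlowMatchEulerDiscreteScheduler is first order

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
# timesteps[t_start * order:] keeps the last 30 of the 50 scheduled steps
print(init_timestep, t_start)
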
@@ -647,7 +659,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
                 for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
@@ -706,32 +718,34 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        prompt_3: Optional[Union[str, List[str]]] = None,
-        image: PipelineImageInput = None,
-        strength: float = 0.6,
-        num_inference_steps: int = 50,
-        sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 7.0,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 256,
+            self,
+            prompt: Union[str, List[str]] = None,
+            prompt_2: Optional[Union[str, List[str]]] = None,
+            prompt_3: Optional[Union[str, List[str]]] = None,
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            image: PipelineImageInput = None,
+            strength: float = 0.6,
+            num_inference_steps: int = 50,
+            sigmas: Optional[List[float]] = None,
+            guidance_scale: float = 7.0,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            negative_prompt_3: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+            clip_skip: Optional[int] = None,
+            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+            max_sequence_length: int = 256,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -824,12 +838,16 @@ def __call__(
                 [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             prompt_2,
             prompt_3,
+            height,
+            width,
             strength,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
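
When `height` or `width` is not passed, the new defaults resolve from the model configuration before validation runs. Assuming the stock SD3 checkpoint (`transformer.config.sample_size = 128`, `vae_scale_factor = 8`), the fallback works out as:

default_sample_size = 128   # assumed transformer.config.sample_size for SD3
vae_scale_factor = 8        # assumed SD3 VAE scale factor

height = None               # caller did not pass height
width = 512                 # caller passed an explicit width

height = height or default_sample_size * vae_scale_factor  # -> 1024
width = width or default_sample_size * vae_scale_factor    # -> stays 512
print(height, width)
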
@@ -890,7 +908,7 @@ def __call__(
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 3. Preprocess image
-        image = self.image_processor.preprocess(image)
+        image = self.image_processor.preprocess(image, height, width)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
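
With the new arguments wired through `check_inputs` and `image_processor.preprocess`, callers can pin the output resolution explicitly. A minimal usage sketch; the checkpoint name, input file, and sizes are illustrative and not part of the diff:

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("input.jpg")  # any RGB image on disk

# height/width must be multiples of vae_scale_factor * patch_size (16 for SD3);
# a value such as 1000 would now raise a ValueError from check_inputs.
result = pipe(
    prompt="A fantasy landscape, trending on artstation",
    image=init_image,
    height=1024,
    width=768,
    strength=0.6,
).images[0]
result.save("sd3_img2img.png")
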
