Commit fe32ade

make style
1 parent 1ab5d8c commit fe32ade

1 file changed: +98 -97 lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 98 additions & 97 deletions
@@ -41,6 +41,7 @@
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusion3PipelineOutput

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -75,7 +76,7 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
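
Note (not part of the commit): `retrieve_latents` pulls a latent tensor out of whatever the VAE's `encode` call returns, sampling from the posterior by default or taking its mode for `sample_mode="argmax"`. A minimal sketch, assuming a standard diffusers `AutoencoderKL` checkpoint and a dummy input tensor in place of a real preprocessed image:

import torch
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import retrieve_latents

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint; any AutoencoderKL works
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
generator = torch.Generator().manual_seed(0)

# Draws a sample from the VAE posterior; sample_mode="argmax" would use latent_dist.mode() instead
latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="sample")
print(latents.shape)  # e.g. torch.Size([1, 4, 64, 64]) for this VAE
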
@@ -89,12 +90,12 @@ def retrieve_latents(

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
-    **kwargs,
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
 ):
     r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
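
For context (not part of the commit): as its docstring says, `retrieve_timesteps` forwards to `scheduler.set_timesteps` and returns the resulting timestep tensor together with the effective step count, raising if custom `timesteps` or `sigmas` are passed to a scheduler that does not accept them. A rough usage sketch, assuming the `FlowMatchEulerDiscreteScheduler` used by this pipeline with its default config:

from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()  # default config, for illustration only
# Pass num_inference_steps, or custom timesteps/sigmas if the scheduler's set_timesteps supports them
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")
print(num_inference_steps, timesteps[:3])  # 28 steps, in the scheduler's (descending) order
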
@@ -186,16 +187,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

     def __init__(
-        self,
-        transformer: SD3Transformer2DModel,
-        scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModelWithProjection,
-        tokenizer: CLIPTokenizer,
-        text_encoder_2: CLIPTextModelWithProjection,
-        tokenizer_2: CLIPTokenizer,
-        text_encoder_3: T5EncoderModel,
-        tokenizer_3: T5TokenizerFast,
+        self,
+        transformer: SD3Transformer2DModel,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer_2: CLIPTokenizer,
+        text_encoder_3: T5EncoderModel,
+        tokenizer_3: T5TokenizerFast,
     ):
         super().__init__()

@@ -222,12 +223,12 @@ def __init__(

     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 256,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -258,7 +259,7 @@ def _get_t5_prompt_embeds(
         untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
+            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
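
Aside from reindentation, the only visible change in this and the later hunks is slice spacing: the formatter run by `make style` follows the PEP 8 / black rule of treating the slice colon like a binary operator, so a bound that is a compound expression gets a space on both sides. A small illustrative sketch (not from the commit):

ids = list(range(100))
max_length = 77
# before: asymmetric spacing around the slice colon
before = ids[max_length - 1: -1]
# after `make style`: a space on both sides, because the lower bound is an expression
after = ids[max_length - 1 : -1]
assert before == after  # the slices are identical at runtime; only the formatting differs
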
@@ -279,12 +280,12 @@ def _get_t5_prompt_embeds(

     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
     def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        clip_skip: Optional[int] = None,
-        clip_model_index: int = 0,
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        clip_skip: Optional[int] = None,
+        clip_model_index: int = 0,
     ):
         device = device or self._execution_device

@@ -308,7 +309,7 @@ def _get_clip_prompt_embeds(
         text_input_ids = text_inputs.input_ids
         untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
+            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer_max_length} tokens: {removed_text}"
@@ -335,23 +336,23 @@ def _get_clip_prompt_embeds(

     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
     def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        clip_skip: Optional[int] = None,
-        max_sequence_length: int = 256,
-        lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        prompt_3: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
+        lora_scale: Optional[float] = None,
     ):
         r"""

@@ -528,26 +529,26 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        prompt_3,
-        height,
-        width,
-        strength,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        negative_prompt_3=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        height,
+        width,
+        strength,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        negative_prompt_3=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if (
-            height % (self.vae_scale_factor * self.patch_size) != 0
-            or width % (self.vae_scale_factor * self.patch_size) != 0
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
         ):
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
@@ -558,7 +559,7 @@ def check_inputs(
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -632,7 +633,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(num_inference_steps * strength, num_inference_steps)

         t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
         if hasattr(self.scheduler, "set_begin_index"):
             self.scheduler.set_begin_index(t_start * self.scheduler.order)

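A quick worked example of the numbers in `get_timesteps` above (illustrative, not part of the diff): with the pipeline defaults of 50 steps and strength 0.6, the img2img run keeps the last 30 scheduler timesteps.

num_inference_steps = 50
strength = 0.6
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
# With scheduler.order == 1 (FlowMatchEulerDiscreteScheduler), the pipeline denoises timesteps[20:],
# i.e. 30 of the 50 scheduled steps, starting from a partially noised version of the input image.
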
@@ -659,7 +660,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt

         elif isinstance(generator, list):
             init_latents = [
-                retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
@@ -718,34 +719,34 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        prompt_3: Optional[Union[str, List[str]]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        image: PipelineImageInput = None,
-        strength: float = 0.6,
-        num_inference_steps: int = 50,
-        sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 7.0,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt_3: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 256,
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        image: PipelineImageInput = None,
+        strength: float = 0.6,
+        num_inference_steps: int = 50,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 7.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
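
The restyle does not change behavior; a minimal img2img call against the `__call__` signature above might look like the sketch below (the checkpoint id and image URL are illustrative placeholders, not taken from this commit):

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("https://example.com/input.png").resize((1024, 1024))  # placeholder URL
result = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    image=init_image,
    strength=0.6,            # how much of the input image to re-noise
    num_inference_steps=50,
    guidance_scale=7.0,
).images[0]
result.save("sd3_img2img.png")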
