Commit ac4a132 (parent 2173054)

Apply style fixes

2 files changed: +101 -104 lines changed

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -290,7 +290,12 @@
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
     _import_structure["pia"] = ["PIAPipeline"]
     _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
-    _import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline", "SanaControlNetPipeline", "SanaSprintImg2ImgPipeline"]
+    _import_structure["sana"] = [
+        "SanaPipeline",
+        "SanaSprintPipeline",
+        "SanaControlNetPipeline",
+        "SanaSprintImg2ImgPipeline",
+    ]
     _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
     _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
     _import_structure["stable_audio"] = [
@@ -675,7 +680,7 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline, SanaSprintImg2ImgPipeline
+        from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
         from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
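Note: the change above only splits the lazy _import_structure entry across lines and sorts the TYPE_CHECKING import; the public import path is unchanged. A minimal sketch of the resulting user-facing import (assumed typical usage, not part of this commit):

from diffusers import SanaSprintImg2ImgPipeline  # resolved lazily via _import_structure on first access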

src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py

Lines changed: 94 additions & 102 deletions
@@ -23,7 +23,7 @@
 from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PixArtImageProcessor, PipelineImageInput
+from ...image_processor import PipelineImageInput, PixArtImageProcessor
 from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...schedulers import DPMSolverMultistepScheduler
@@ -43,6 +43,7 @@
 from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
 from .pipeline_output import SanaPipelineOutput

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -77,12 +78,12 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-        scheduler,
-        num_inference_steps: Optional[int] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        timesteps: Optional[List[int]] = None,
-        sigmas: Optional[List[float]] = None,
-        **kwargs,
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
 ):
     r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
@@ -149,12 +150,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
-            self,
-            tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
-            text_encoder: Gemma2PreTrainedModel,
-            vae: AutoencoderDC,
-            transformer: SanaTransformer2DModel,
-            scheduler: DPMSolverMultistepScheduler,
+        self,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
+        vae: AutoencoderDC,
+        transformer: SanaTransformer2DModel,
+        scheduler: DPMSolverMultistepScheduler,
     ):
         super().__init__()

@@ -200,13 +201,13 @@ def disable_vae_tiling(self):

     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
     def _get_gemma_prompt_embeds(
-            self,
-            prompt: Union[str, List[str]],
-            device: torch.device,
-            dtype: torch.dtype,
-            clean_caption: bool = False,
-            max_sequence_length: int = 300,
-            complex_human_instruction: Optional[List[str]] = None,
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -258,16 +259,16 @@ def _get_gemma_prompt_embeds(
         return prompt_embeds, prompt_attention_mask

     def encode_prompt(
-            self,
-            prompt: Union[str, List[str]],
-            num_images_per_prompt: int = 1,
-            device: Optional[torch.device] = None,
-            prompt_embeds: Optional[torch.Tensor] = None,
-            prompt_attention_mask: Optional[torch.Tensor] = None,
-            clean_caption: bool = False,
-            max_sequence_length: int = 300,
-            complex_human_instruction: Optional[List[str]] = None,
-            lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -366,25 +367,25 @@ def get_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(num_inference_steps * strength, num_inference_steps)

         t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
         if hasattr(self.scheduler, "set_begin_index"):
             self.scheduler.set_begin_index(t_start * self.scheduler.order)

         return timesteps, num_inference_steps - t_start

     def check_inputs(
-            self,
-            prompt,
-            strength,
-            height,
-            width,
-            num_inference_steps,
-            timesteps,
-            max_timesteps,
-            intermediate_timesteps,
-            callback_on_step_end_tensor_inputs=None,
-            prompt_embeds=None,
-            prompt_attention_mask=None,
+        self,
+        prompt,
+        strength,
+        height,
+        width,
+        num_inference_steps,
+        timesteps,
+        max_timesteps,
+        intermediate_timesteps,
+        callback_on_step_end_tensor_inputs=None,
+        prompt_embeds=None,
+        prompt_attention_mask=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -393,7 +394,7 @@ def check_inputs(
             raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
-                k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -571,12 +572,12 @@ def _clean_caption(self, caption):

     # Copied from diffusers.pipelines.sana.pipeline_sana_controlnet.SanaPipeline.prepare_latents
     def prepare_image(
-            self,
-            image,
-            width,
-            height,
-            device,
-            dtype,
+        self,
+        image,
+        width,
+        height,
+        device,
+        dtype,
     ):
         if isinstance(image, torch.Tensor):
             pass
@@ -588,17 +589,9 @@ def prepare_image(
         return image

     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
-    def prepare_latents(self,
-                        image,
-                        timestep,
-                        batch_size,
-                        num_channels_latents,
-                        height,
-                        width,
-                        dtype,
-                        device,
-                        generator,
-                        latents=None):
+    def prepare_latents(
+        self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None
+    ):
         if latents is not None:
             return latents.to(device=device, dtype=dtype)

@@ -609,7 +602,6 @@ def prepare_latents(self,
             int(width) // self.vae_scale_factor,
         )

-
         if image.shape[1] != num_channels_latents:
             image = self.vae.encode(image).latent
             image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
@@ -657,41 +649,41 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-            self,
-            prompt: Union[str, List[str]] = None,
-            num_inference_steps: int = 2,
-            timesteps: List[int] = None,
-            max_timesteps: float = 1.57080,
-            intermediate_timesteps: float = 1.3,
-            guidance_scale: float = 4.5,
-            image: PipelineImageInput = None,
-            strength: float = 0.6,
-            num_images_per_prompt: Optional[int] = 1,
-            height: int = 1024,
-            width: int = 1024,
-            eta: float = 0.0,
-            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-            latents: Optional[torch.Tensor] = None,
-            prompt_embeds: Optional[torch.Tensor] = None,
-            prompt_attention_mask: Optional[torch.Tensor] = None,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            clean_caption: bool = False,
-            use_resolution_binning: bool = True,
-            attention_kwargs: Optional[Dict[str, Any]] = None,
-            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-            max_sequence_length: int = 300,
-            complex_human_instruction: List[str] = [
-                "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
-                "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
-                "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
-                "Here are examples of how to transform or refine prompts:",
-                "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
-                "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
-                "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
-                "User Prompt: ",
-            ],
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_inference_steps: int = 2,
+        timesteps: List[int] = None,
+        max_timesteps: float = 1.57080,
+        intermediate_timesteps: float = 1.3,
+        guidance_scale: float = 4.5,
+        image: PipelineImageInput = None,
+        strength: float = 0.6,
+        num_images_per_prompt: Optional[int] = 1,
+        height: int = 1024,
+        width: int = 1024,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        clean_caption: bool = False,
+        use_resolution_binning: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 300,
+        complex_human_instruction: List[str] = [
+            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+            "Here are examples of how to transform or refine prompts:",
+            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+            "User Prompt: ",
+        ],
     ) -> Union[SanaPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
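For orientation, the signature reformatted above is the image-to-image entry point added by this PR. A minimal usage sketch based on the parameters shown in the diff; the checkpoint id and file paths are placeholders, not taken from this commit:

import torch
from diffusers import SanaSprintImg2ImgPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id; substitute a real SANA-Sprint checkpoint
pipe = SanaSprintImg2ImgPipeline.from_pretrained("<sana-sprint-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

init_image = load_image("input.png")  # placeholder input path
image = pipe(
    prompt="a cyberpunk cat",
    image=init_image,
    strength=0.6,
    num_inference_steps=2,
).images[0]
image.save("output.png")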
@@ -874,7 +866,7 @@ def __call__(
         )

         # I think this is redundant given the scaling in prepare_latents
-        #latents = latents * self.scheduler.config.sigma_data
+        # latents = latents * self.scheduler.config.sigma_data

         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
         guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
@@ -902,7 +894,7 @@ def __call__(

                 scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
                 latent_model_input = latents_model_input * torch.sqrt(
-                    scm_timestep_expanded ** 2 + (1 - scm_timestep_expanded) ** 2
+                    scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                 )

                 # predict noise model_output
@@ -917,9 +909,9 @@ def __call__(
                 )[0]

                 noise_pred = (
-                    (1 - 2 * scm_timestep_expanded) * latent_model_input
-                    + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded ** 2) * noise_pred
-                ) / torch.sqrt(scm_timestep_expanded ** 2 + (1 - scm_timestep_expanded) ** 2)
+                    (1 - 2 * scm_timestep_expanded) * latent_model_input
+                    + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
+                ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
                 noise_pred = noise_pred.float() * self.scheduler.config.sigma_data

                 # compute previous image: x_t -> x_t-1
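The two hunks above only normalize spacing around the power operator; `a ** 2` and `a**2` are the same expression, so the input scaling and noise prediction are numerically unchanged. A quick check with illustrative tensors (names are not from the commit):

import torch

t = torch.rand(2, 1, 1, 1)
assert torch.equal(t ** 2, t**2)  # Black drops spaces around ** for simple operands
assert torch.equal(torch.sqrt(t ** 2 + (1 - t) ** 2), torch.sqrt(t**2 + (1 - t) ** 2))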
