Commit 5fb973c
make fix-copies
1 parent: b7837c0

9 files changed: +153 -19 lines

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 36 additions & 5 deletions
@@ -171,6 +171,7 @@ def encode_prompt(
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         clean_caption: bool = False,
         max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -192,11 +193,13 @@ def encode_prompt(
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
-                string.
+                Pre-generated negative text embeddings. For Sana, it should be the embeddings of the "" string.
             clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
             max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex human instruction for
+                the prompt.
         """

         if device is None:
@@ -209,15 +212,28 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]

+        self.tokenizer.padding_side = "right"
+
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
+        select_index = [0] + list(range(-max_length + 1, 0))

         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+            # prepare complex human instruction
+            if not complex_human_instruction:
+                max_length_all = max_length
+            else:
+                chi_prompt = "\n".join(complex_human_instruction)
+                prompt = [chi_prompt + p for p in prompt]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = num_chi_prompt_tokens + max_length - 2
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
-                max_length=max_length,
+                max_length=max_length_all,
                 truncation=True,
                 add_special_tokens=True,
                 return_tensors="pt",
@@ -228,7 +244,8 @@ def encode_prompt(
             prompt_attention_mask = prompt_attention_mask.to(device)

             prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_attention_mask = prompt_attention_mask[:, select_index]

         if self.transformer is not None:
             dtype = self.transformer.dtype
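
Note on the two hunks above: `select_index = [0] + list(range(-max_length + 1, 0))` keeps the BOS position plus the last `max_length - 1` positions, so the CHI prefix conditions the text encoder but is dropped from the embeddings handed onward, while `max_length_all = num_chi_prompt_tokens + max_length - 2` widens the tokenizer budget to fit the prefix (the `- 2` presumably offsets special tokens counted twice when the prefix is tokenized on its own — that reading is an assumption). A toy sketch with made-up sizes, not the real tokenizer:

import torch

# Hypothetical sizes standing in for real tokenizer output.
max_length = 300
num_chi_prompt_tokens = 120  # tokens in the CHI prefix alone
max_length_all = num_chi_prompt_tokens + max_length - 2  # 418

# Pretend encoder output: (batch, seq_len, hidden).
prompt_embeds = torch.randn(1, max_length_all, 8)
prompt_attention_mask = torch.ones(1, max_length_all, dtype=torch.long)

# Keep BOS plus the trailing max_length - 1 positions.
select_index = [0] + list(range(-max_length + 1, 0))
print(prompt_embeds[:, select_index].shape)           # torch.Size([1, 300, 8])
print(prompt_attention_mask[:, select_index].shape)   # torch.Size([1, 300])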
@@ -317,7 +334,7 @@ def check_inputs(
         negative_prompt_attention_mask=None,
     ):
         if height % 32 != 0 or width % 32 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -573,6 +590,16 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
+        complex_human_instruction: List[str] = [
+            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+            "Here are examples of how to transform or refine prompts:",
+            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+            "User Prompt: ",
+        ],
         pag_scale: float = 3.0,
         pag_adaptive_scale: float = 0.0,
     ) -> Union[ImagePipelineOutput, Tuple]:
@@ -652,6 +679,9 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`.
+            complex_human_instruction (`List[str]`, *optional*):
+                Instructions for complex human attention:
+                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
             pag_scale (`float`, *optional*, defaults to 3.0):
                 The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
                 guidance will not be used.
@@ -727,6 +757,7 @@ def __call__(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
             max_sequence_length=max_sequence_length,
+            complex_human_instruction=complex_human_instruction,
         )

         if self.do_perturbed_attention_guidance:
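
For trying the new argument end to end, a minimal sketch — the checkpoint id is a placeholder and loading details may differ; passing `None` (or an empty list) skips the CHI prefix entirely, since `encode_prompt` only applies it when the list is non-empty:

import torch
from diffusers import SanaPAGPipeline

# Placeholder checkpoint id — substitute a real Sana checkpoint.
pipe = SanaPAGPipeline.from_pretrained("org/sana-checkpoint", torch_dtype=torch.float16)
pipe.to("cuda")

# The default complex_human_instruction is used unless overridden here.
image = pipe(
    prompt="A busy city street",
    pag_scale=3.0,
    complex_human_instruction=None,  # opt out of the CHI prefix
).images[0]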

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 2 additions & 2 deletions
@@ -582,7 +582,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
-        complex_human_instruction: list[str] = [
+        complex_human_instruction: List[str] = [
             "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
             "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
             "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
@@ -670,7 +670,7 @@ def __call__(
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to `300`):
                 Maximum sequence length to use with the `prompt`.
-            complex_human_instruction (`list[str]`, *optional*):
+            complex_human_instruction (`List[str]`, *optional*):
                 Instructions for complex human attention:
                 https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 13 additions & 2 deletions
@@ -149,6 +149,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
@@ -282,6 +284,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
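
This branch (repeated verbatim in the other schedulers below) builds the shifted flow-matching schedule: sigma runs over [0, 1) and `flow_shift` warps it via sigma -> s * sigma / (1 + (s - 1) * sigma), which fixes the endpoints and pushes intermediate values toward the noisy end when s > 1. A standalone sketch of the arithmetic, mirroring the diff rather than calling the scheduler:

import numpy as np

num_train_timesteps, num_inference_steps, shift = 1000, 4, 3.0

alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas  # ascending in [0, 1)
sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1]  # descending
timesteps = sigmas * num_train_timesteps  # high-noise steps come first

print(sigmas.round(4))  # approx. [0.9997 0.8996 0.7496 0.4997]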
@@ -362,8 +369,12 @@ def _sigma_to_t(self, sigma, log_sigmas):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t
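
The flow branch follows from the flow-matching interpolation x_t = (1 - sigma) * x0 + sigma * noise, under which the clean-sample coefficient is alpha_t = 1 - sigma and the noise coefficient is sigma_t = sigma, so the VP-style 1 / sqrt(sigma**2 + 1) rescaling does not apply. A quick numerical check of that identity:

import torch

sigma = 0.3
x0, noise = torch.randn(4), torch.randn(4)

x_t = (1 - sigma) * x0 + sigma * noise  # flow-matching interpolation

alpha_t, sigma_t = 1 - sigma, sigma     # the scheduler's flow branch
assert torch.allclose(x_t, alpha_t * x0 + sigma_t * noise)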

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 18 additions & 4 deletions
@@ -169,6 +169,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
@@ -292,6 +294,11 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         elif self.config.use_beta_sigmas:
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -379,8 +386,12 @@ def _sigma_to_t(self, sigma, log_sigmas):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t

@@ -522,10 +533,13 @@ def convert_model_output(
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 x0_pred = alpha_t * sample - sigma_t * model_output
+            elif self.config.prediction_type == "flow_prediction":
+                sigma_t = self.sigmas[self.step_index]
+                x0_pred = sample - sigma_t * model_output
             else:
                 raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverMultistepScheduler."
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+                    " `v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                 )

             if self.config.thresholding:
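
For `flow_prediction` the model output is the velocity v = noise - x0 (assuming the usual rectified-flow convention), and substituting the interpolation x_t = (1 - sigma) * x0 + sigma * noise gives x_t - sigma * v = x0, which is exactly the new one-liner. A numerical sanity check:

import torch

sigma_t = 0.6
x0, noise = torch.randn(4), torch.randn(4)

sample = (1 - sigma_t) * x0 + sigma_t * noise  # current noisy latent
model_output = noise - x0                      # predicted velocity

x0_pred = sample - sigma_t * model_output      # matches the new branch
assert torch.allclose(x0_pred, x0, atol=1e-6)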

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 13 additions & 2 deletions
@@ -164,6 +164,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
@@ -356,6 +358,11 @@ def set_timesteps(
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -454,8 +461,12 @@ def _sigma_to_t(self, sigma, log_sigmas):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t

src/diffusers/schedulers/scheduling_sasolver.py

Lines changed: 13 additions & 2 deletions
@@ -167,6 +167,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
@@ -311,6 +313,11 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -391,8 +398,12 @@ def _sigma_to_t(self, sigma, log_sigmas):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 13 additions & 2 deletions
@@ -206,6 +206,8 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_sigmas: Optional[bool] = False,
+        flow_shift: Optional[float] = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
@@ -374,6 +376,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
                     f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.use_flow_sigmas:
+            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+            sigmas = 1.0 - alphas
+            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
+            timesteps = (sigmas * self.config.num_train_timesteps).copy()
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             if self.config.final_sigmas_type == "sigma_min":
@@ -464,8 +471,12 @@ def _sigma_to_t(self, sigma, log_sigmas):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
-        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
-        sigma_t = sigma * alpha_t
+        if self.config.use_flow_sigmas:
+            alpha_t = 1 - sigma
+            sigma_t = sigma
+        else:
+            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+            sigma_t = sigma * alpha_t

         return alpha_t, sigma_t

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
@@ -557,6 +557,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class SanaTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class SD3ControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,36 @@ def from_pretrained(cls, *args, **kwargs):
12621262
requires_backends(cls, ["torch", "transformers"])
12631263

12641264

1265+
class SanaPAGPipeline(metaclass=DummyObject):
1266+
_backends = ["torch", "transformers"]
1267+
1268+
def __init__(self, *args, **kwargs):
1269+
requires_backends(self, ["torch", "transformers"])
1270+
1271+
@classmethod
1272+
def from_config(cls, *args, **kwargs):
1273+
requires_backends(cls, ["torch", "transformers"])
1274+
1275+
@classmethod
1276+
def from_pretrained(cls, *args, **kwargs):
1277+
requires_backends(cls, ["torch", "transformers"])
1278+
1279+
1280+
class SanaPipeline(metaclass=DummyObject):
1281+
_backends = ["torch", "transformers"]
1282+
1283+
def __init__(self, *args, **kwargs):
1284+
requires_backends(self, ["torch", "transformers"])
1285+
1286+
@classmethod
1287+
def from_config(cls, *args, **kwargs):
1288+
requires_backends(cls, ["torch", "transformers"])
1289+
1290+
@classmethod
1291+
def from_pretrained(cls, *args, **kwargs):
1292+
requires_backends(cls, ["torch", "transformers"])
1293+
1294+
12651295
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
12661296
_backends = ["torch", "transformers"]
12671297
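
These stubs are what `make fix-copies` regenerates: they let `from diffusers import SanaPipeline` succeed even when torch or transformers is missing, deferring the failure to use time via `requires_backends`. A simplified reproduction of the pattern (the real `DummyObject` metaclass also intercepts class attribute access, so this is a sketch, not the diffusers source):

def requires_backends(obj, backends):
    name = obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
    raise ImportError(f"{name} requires the {backends} backends, which are not installed.")


class SanaPipeline:  # stand-in for the generated dummy class
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


try:
    SanaPipeline.from_pretrained("any/checkpoint")
except ImportError as err:
    print(err)  # SanaPipeline requires the ['torch', 'transformers'] backends...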
