Skip to content

Commit 0debade

Browse files
committed
make fix-copies
1 parent 9d31426 commit 0debade

File tree

2 files changed

+5
-17
lines changed

2 files changed

+5
-17
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -180,7 +180,7 @@ def __init__(
180180
pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
181181
)
182182

183-
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
183+
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
184184
def encode_prompt(
185185
self,
186186
prompt: Union[str, List[str]],
@@ -194,7 +194,6 @@ def encode_prompt(
194194
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
195195
clean_caption: bool = False,
196196
max_sequence_length: int = 300,
197-
**kwargs,
198197
):
199198
r"""
200199
Encodes the prompt into text encoder hidden states.
@@ -223,10 +222,6 @@ def encode_prompt(
223222
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
224223
"""
225224

226-
if "mask_feature" in kwargs:
227-
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
228-
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
229-
230225
if device is None:
231226
device = self._execution_device
232227

@@ -251,16 +246,6 @@ def encode_prompt(
251246
return_tensors="pt",
252247
)
253248
text_input_ids = text_inputs.input_ids
254-
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
255-
256-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
257-
text_input_ids, untruncated_ids
258-
):
259-
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
260-
logger.warning(
261-
"The following part of your input was truncated because T5 can only handle sequences up to"
262-
f" {max_length} tokens: {removed_text}"
263-
)
264249

265250
prompt_attention_mask = text_inputs.attention_mask
266251
prompt_attention_mask = prompt_attention_mask.to(device)
@@ -568,6 +553,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
568553
else:
569554
latents = latents.to(device)
570555

556+
# scale the initial noise by the standard deviation required by the scheduler
557+
latents = latents * self.scheduler.init_noise_sigma
571558
return latents
572559

573560
@torch.no_grad()

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -319,7 +319,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
319319
extra_step_kwargs["generator"] = generator
320320
return extra_step_kwargs
321321

322-
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
323322
def check_inputs(
324323
self,
325324
prompt,
@@ -537,6 +536,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
537536
else:
538537
latents = latents.to(device)
539538

539+
# scale the initial noise by the standard deviation required by the scheduler
540+
latents = latents * self.scheduler.init_noise_sigma
540541
return latents
541542

542543
@torch.no_grad()

0 commit comments

Comments (0)