Skip to content

Commit 9e0ca0b

Browse files
committed
apply review suggestions
1 parent eee08ad commit 9e0ca0b

File tree

4 files changed

+17
-35
lines changed

4 files changed

+17
-35
lines changed

src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@
187187
"EDMEulerScheduler",
188188
"EulerAncestralDiscreteScheduler",
189189
"EulerDiscreteScheduler",
190-
"FlowDPMSolverMultistepScheduler",
191190
"FlowMatchEulerDiscreteScheduler",
192191
"FlowMatchHeunDiscreteScheduler",
193192
"HeunDiscreteScheduler",
@@ -692,7 +691,6 @@
692691
EDMEulerScheduler,
693692
EulerAncestralDiscreteScheduler,
694693
EulerDiscreteScheduler,
695-
FlowDPMSolverMultistepScheduler,
696694
FlowMatchEulerDiscreteScheduler,
697695
FlowMatchHeunDiscreteScheduler,
698696
HeunDiscreteScheduler,

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5446,11 +5446,6 @@ class SanaLinearAttnProcessor2_0:
54465446
Processor for implementing scaled dot-product linear attention.
54475447
"""
54485448

5449-
def __init__(self, pad_val=1.0, eps=1e-15):
5450-
self.pad_val = pad_val
5451-
self.eps = eps
5452-
self.kernel_func = nn.ReLU(inplace=False)
5453-
54545449
def __call__(
54555450
self,
54565451
attn: Attention,
@@ -5471,16 +5466,16 @@ def __call__(
54715466
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
54725467
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
54735468

5474-
query = self.kernel_func(query)
5475-
key = self.kernel_func(key)
5469+
query = F.relu(query)
5470+
key = F.relu(key)
54765471

54775472
query, key, value = query.float(), key.float(), value.float()
54785473

5479-
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
5474+
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
54805475
scores = torch.matmul(value, key)
54815476
hidden_states = torch.matmul(scores, query)
54825477

5483-
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
5478+
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
54845479
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
54855480
hidden_states = hidden_states.to(original_dtype)
54865481

@@ -5498,11 +5493,6 @@ class PAGCFGSanaLinearAttnProcessor2_0:
54985493
Processor for implementing scaled dot-product linear attention.
54995494
"""
55005495

5501-
def __init__(self, pad_val=1.0, eps=1e-15):
5502-
self.pad_val = pad_val
5503-
self.eps = eps
5504-
self.kernel_func = nn.ReLU(inplace=False)
5505-
55065496
def __call__(
55075497
self,
55085498
attn: Attention,
@@ -5523,16 +5513,16 @@ def __call__(
55235513
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
55245514
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
55255515

5526-
query = self.kernel_func(query)
5527-
key = self.kernel_func(key)
5516+
query = F.relu(query)
5517+
key = F.relu(key)
55285518

55295519
query, key, value = query.float(), key.float(), value.float()
55305520

5531-
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
5521+
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
55325522
scores = torch.matmul(value, key)
55335523
hidden_states_org = torch.matmul(scores, query)
55345524

5535-
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
5525+
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
55365526
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
55375527
hidden_states_org = hidden_states_org.to(original_dtype)
55385528

@@ -5558,11 +5548,6 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
55585548
Processor for implementing scaled dot-product linear attention.
55595549
"""
55605550

5561-
def __init__(self, pad_val=1.0, eps=1e-15):
5562-
self.pad_val = pad_val
5563-
self.eps = eps
5564-
self.kernel_func = nn.ReLU(inplace=False)
5565-
55665551
def __call__(
55675552
self,
55685553
attn: Attention,
@@ -5587,14 +5572,14 @@ def __call__(
55875572

55885573
query, key, value = query.float(), key.float(), value.float()
55895574

5590-
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
5575+
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
55915576
scores = torch.matmul(value, key)
55925577
hidden_states_org = torch.matmul(scores, query)
55935578

55945579
if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
55955580
hidden_states_org = hidden_states_org.float()
55965581

5597-
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
5582+
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
55985583
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
55995584
hidden_states_org = hidden_states_org.to(original_dtype)
56005585

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ...image_processor import PixArtImageProcessor
2626
from ...models import AutoencoderDC, SanaTransformer2DModel
2727
from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
28-
from ...schedulers import FlowDPMSolverMultistepScheduler
28+
from ...schedulers import FlowMatchEulerDiscreteScheduler
2929
from ...utils import (
3030
BACKENDS_MAPPING,
3131
is_bs4_available,
@@ -140,7 +140,7 @@ def __init__(
140140
text_encoder: AutoModelForCausalLM,
141141
vae: AutoencoderDC,
142142
transformer: SanaTransformer2DModel,
143-
scheduler: FlowDPMSolverMultistepScheduler,
143+
scheduler: FlowMatchEulerDiscreteScheduler,
144144
pag_applied_layers: Union[str, List[str]] = "transformer_blocks.0",
145145
):
146146
super().__init__()
@@ -316,7 +316,7 @@ def check_inputs(
316316
prompt_attention_mask=None,
317317
negative_prompt_attention_mask=None,
318318
):
319-
if height % 8 != 0 or width % 8 != 0:
319+
if height % 32 != 0 or width % 32 != 0:
320320
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
321321

322322
if callback_on_step_end_tensor_inputs is not None and not all(

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2525
from ...image_processor import PixArtImageProcessor
2626
from ...models import AutoencoderDC, SanaTransformer2DModel
27-
from ...schedulers import FlowDPMSolverMultistepScheduler
27+
from ...schedulers import FlowMatchEulerDiscreteScheduler
2828
from ...utils import (
2929
BACKENDS_MAPPING,
3030
is_bs4_available,
@@ -137,7 +137,7 @@ def __init__(
137137
text_encoder: AutoModelForCausalLM,
138138
vae: AutoencoderDC,
139139
transformer: SanaTransformer2DModel,
140-
scheduler: FlowDPMSolverMultistepScheduler,
140+
scheduler: FlowMatchEulerDiscreteScheduler,
141141
):
142142
super().__init__()
143143

@@ -187,8 +187,7 @@ def encode_prompt(
187187
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
188188
provided, text embeddings will be generated from `prompt` input argument.
189189
negative_prompt_embeds (`torch.Tensor`, *optional*):
190-
Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the ""
191-
string.
190+
Pre-generated negative text embeddings. For Sana, it should be the embeddings of the "" string.
192191
clean_caption (`bool`, defaults to `False`):
193192
If `True`, the function will preprocess and clean the provided caption before encoding.
194193
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
@@ -325,7 +324,7 @@ def check_inputs(
325324
prompt_attention_mask=None,
326325
negative_prompt_attention_mask=None,
327326
):
328-
if height % 8 != 0 or width % 8 != 0:
327+
if height % 32 != 0 or width % 32 != 0:
329328
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
330329

331330
if callback_on_step_end_tensor_inputs is not None and not all(

0 commit comments

Comments
 (0)