
Commit 86bef58

Modify the attention processor via set_attn_processor and rename SanaAttnProcessor3_0 to SanaVanillaAttnProcessor
1 parent 9cb050b commit 86bef58
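
For users of the training script, the change boils down to the following migration (a minimal sketch; the checkpoint path is a placeholder, and SanaVanillaAttnProcessor is the class this commit adds to train_sana_sprint_diffusers.py):

from diffusers import SanaTransformer2DModel

# Before this commit the processor was selected through a custom config flag on the model:
# transformer = SanaTransformer2DModel.from_pretrained(
#     "path/to/sana-sprint", subfolder="transformer", guidance_embeds=True,
#     cross_attention_type="vanilla",
# )

# After this commit the stock model is loaded and the processor is attached from the script:
transformer = SanaTransformer2DModel.from_pretrained(
    "path/to/sana-sprint", subfolder="transformer", guidance_embeds=True
)
transformer.set_attn_processor(SanaVanillaAttnProcessor())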

File tree: 2 files changed, +84 -115 lines


examples/research_projects/sana/train_sana_sprint_diffusers.py

Lines changed: 83 additions & 19 deletions
@@ -86,6 +86,86 @@
     "User Prompt: ",
 ]
 
+class SanaVanillaAttnProcessor:
+    r"""
+    Processor for implementing scaled dot-product attention to support JVP calculation during training.
+    """
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    ) -> torch.Tensor:
+        B, H, L, S = *query.size()[:-1], key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        hidden_states = self.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
 
 
 class Text2ImageDataset(Dataset):
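
The docstring above states the motivation: the sCM part of the SANA-Sprint objective needs Jacobian-vector products (forward-mode derivatives) through the transformer, and the fused F.scaled_dot_product_attention backends generally do not support forward-mode AD, so the processor spells the attention math out explicitly. A standalone sketch of taking a JVP through the same formulation (the helper and shapes below are illustrative, not part of the commit):

import math

import torch
from torch.func import jvp

def vanilla_sdpa(query, key, value, scale=None):
    # Same explicit softmax attention as SanaVanillaAttnProcessor.scaled_dot_product_attention,
    # re-implemented here (without masking/dropout) so the snippet runs on its own.
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))  # (batch, heads, seq_len, head_dim)
tangent = torch.randn_like(q)

# Forward-mode AD: propagate a tangent on the query through the attention output in one pass.
out, out_tangent = jvp(lambda q_: vanilla_sdpa(q_, k, v), (q,), (tangent,))
print(out.shape, out_tangent.shape)  # torch.Size([2, 8, 16, 64]) for both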
@@ -109,7 +189,6 @@ def __init__(self, hf_dataset, resolution=1024):
             T.Lambda(lambda img: img.convert("RGB")),
             T.Resize(resolution),  # Image.BICUBIC
             T.CenterCrop(resolution),
-            # T.RandomHorizontalFlip(),
             T.ToTensor(),
             T.Normalize([0.5], [0.5]),
         ])
@@ -132,7 +211,7 @@ def __getitem__(self, idx):
             'image': image_tensor
         }
 
-# TODO here
+
 def save_model_card(
     repo_id: str,
     images=None,
@@ -807,7 +886,6 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
         return (trigflow_model_out,)
 
 
-
 def compute_density_for_timestep_sampling_scm(
     batch_size: int, logit_mean: float = None, logit_std: float = None
 ):
@@ -820,7 +898,6 @@ def compute_density_for_timestep_sampling_scm(
     return u
 
 
-
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -872,7 +949,6 @@ def main(args):
     if args.seed is not None:
         set_seed(args.seed)
 
-
     # Handle the repository creation
     if accelerator.is_main_process:
         if args.output_dir is not None:
@@ -904,8 +980,9 @@ def main(args):
 
     ori_transformer = SanaTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
-        guidance_embeds=True, cross_attention_type='vanilla'
+        guidance_embeds=True,
     )
+    ori_transformer.set_attn_processor(SanaVanillaAttnProcessor())
 
     ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
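
The hunk above loads the unmodified transformer and swaps the processor in from the training script. Note that set_attn_processor with a single instance assigns that processor to every attention module in the model, not only the cross-attention layers that the old cross_attention_type flag controlled; if only cross-attention should take the vanilla path, the dict form can target those layers by name (a sketch assuming the usual "...attn2.processor" key naming of attn_processors):

# Hypothetical targeted variant: replace only the cross-attention (attn2) processors.
processors = {
    name: SanaVanillaAttnProcessor() if ".attn2." in name else proc
    for name, proc in ori_transformer.attn_processors.items()
}
ori_transformer.set_attn_processor(processors)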
@@ -929,7 +1006,6 @@ def main(args):
 
     zero_state_dict = {}
 
-
     target_device = accelerator.device
     param_w1 = guidance_embedder_module.linear_1.weight
     zero_state_dict['linear_1.weight'] = torch.zeros(param_w1.shape, device=target_device)
@@ -941,7 +1017,6 @@ def main(args):
     zero_state_dict['linear_2.bias'] = torch.zeros(param_b2.shape, device=target_device)
     guidance_embedder_module.load_state_dict(zero_state_dict, strict=False, assign=True)
 
-
     transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
     pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()
 
@@ -951,7 +1026,6 @@ def main(args):
         head_block_ids=args.head_block_ids,
     ).train()
 
-
    transformer.requires_grad_(True)
    pretrained_model.requires_grad_(False)
    disc.model.requires_grad_(False)
@@ -1005,7 +1079,6 @@ def main(args):
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1063,7 +1136,6 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
-
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32 and torch.cuda.is_available():
@@ -1087,7 +1159,6 @@ def load_model_hook(models, input_dir):
     else:
         optimizer_class = torch.optim.AdamW
 
-
     # Optimization parameters
     optimizer_G = optimizer_class(
         transformer.parameters(),
@@ -1391,12 +1462,10 @@ def model_wrapper(scaled_x_t, t):
                 z_D = torch.randn_like(model_input) * sigma_data
                 noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D
 
-
                 # Calculate adversarial loss
                 pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
                 adv_loss = -torch.mean(pred_fake)
 
-
                 # Total loss = sCM loss + LADD loss
 
                 total_loss = args.scm_lambda * loss + adv_loss * args.adv_lambda
@@ -1405,8 +1474,6 @@ def model_wrapper(scaled_x_t, t):
 
                 accelerator.backward(total_loss)
 
-
-
                 if accelerator.sync_gradients:
                     grad_norm = accelerator.clip_grad_norm_(transformer.parameters(), args.gradient_clip)
                     if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
@@ -1504,7 +1571,6 @@ def model_wrapper(scaled_x_t, t):
 
                 accelerator.backward(loss_D)
 
-
                 if accelerator.sync_gradients:
                     grad_norm = accelerator.clip_grad_norm_(disc.parameters(), args.gradient_clip)
                     if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
@@ -1519,7 +1585,6 @@ def model_wrapper(scaled_x_t, t):
                 optimizer_D.step()
                 optimizer_D.zero_grad(set_to_none=True)
 
-
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
@@ -1584,7 +1649,6 @@ def model_wrapper(scaled_x_t, t):
                     images = None
                     del pipeline
 
-
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 1 addition & 96 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -186,91 +185,6 @@ def __call__(
         return hidden_states
 
 
-class SanaAttnProcessor3_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    @staticmethod
-    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
-    ) -> torch.Tensor:
-        B, H, L, S = *query.size()[:-1], key.size(-2)
-        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attn_mask
-        attn_weight = query @ key.transpose(-2, -1) * scale_factor
-        attn_weight += attn_bias
-        attn_weight = torch.softmax(attn_weight, dim=-1)
-        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-        return attn_weight @ value
-
-    # return x
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = self.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -291,7 +205,6 @@ def __init__(
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
         qk_norm: Optional[str] = None,
-        cross_attention_type: str = "flash",
     ) -> None:
         super().__init__()
 
@@ -310,12 +223,6 @@ def __init__(
         )
 
         # 2. Cross Attention
-        if cross_attention_type == "flash":
-            cross_attention_processor = SanaAttnProcessor2_0()
-        elif cross_attention_type == "vanilla":
-            cross_attention_processor = SanaAttnProcessor3_0()
-        else:
-            raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             self.attn2 = Attention(
@@ -328,7 +235,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=cross_attention_processor,
+                processor=SanaAttnProcessor2_0(),
             )
 
         # 3. Feed-forward
@@ -453,7 +360,6 @@ def __init__(
         guidance_embeds_scale: float = 0.1,
         qk_norm: Optional[str] = None,
         timestep_scale: float = 1.0,
-        cross_attention_type: str = "flash",
     ) -> None:
         super().__init__()
 
@@ -496,7 +402,6 @@ def __init__(
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
                     qk_norm=qk_norm,
-                    cross_attention_type=cross_attention_type,
                 )
                 for _ in range(num_layers)
             ]