 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
-from timm.models.vision_transformer import Mlp, PatchEmbed, use_fused_attn
+from timm.models.vision_transformer import Mlp, PatchEmbed
 from torch.jit import Final

 from opendit.models.clip import TextEmbedder
@@ -158,7 +157,6 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim**-0.5
-        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -236,13 +234,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 dropout_p=self.attn_drop.p if self.training else 0.0,
                 softmax_scale=self.scale,
             )
-        elif self.fused_attn:
-            x = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
-            )
         else:
             dtype = q.dtype
             q = q * self.scale
@@ -260,7 +251,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             if self.sequence_parallel_size == 1
             else (B, N * self.sequence_parallel_size, num_heads * self.head_dim)
         )
-        x = x.transpose(1, 2).reshape(x_output_shape)
+        if self.enable_flashattn:
+            x = x.reshape(x_output_shape)
+        else:
+            x = x.transpose(1, 2).reshape(x_output_shape)
+
         if self.sequence_parallel_size > 1:
             # Todo: Use all_to_all_single for x
             # x = x.reshape(1, -1, num_heads * self.head_dim)
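The reshape branch added above hinges on the two attention paths producing different output layouts: `flash_attn_func` returns `(B, N, num_heads, head_dim)`, so a plain `reshape` suffices, whereas the matmul path keeps heads in dimension 1 as `(B, num_heads, N, head_dim)` and still needs the `transpose(1, 2)`. A minimal shape-only sketch of that assumption (not code from this PR):

```python
import torch

B, N, num_heads, head_dim = 2, 16, 4, 8
x_output_shape = (B, N, num_heads * head_dim)

# flash-attn convention: output already comes back as (B, N, num_heads, head_dim)
flash_out = torch.randn(B, N, num_heads, head_dim)
assert flash_out.reshape(x_output_shape).shape == x_output_shape

# standard attention (attn @ v) keeps heads in dim 1: (B, num_heads, N, head_dim)
matmul_out = torch.randn(B, num_heads, N, head_dim)
assert matmul_out.transpose(1, 2).reshape(x_output_shape).shape == x_output_shape
```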
@@ -355,6 +350,7 @@ def __init__(
         enable_layernorm_kernel=False,
         enable_modulate_kernel=False,
         sequence_parallel_size=1,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
@@ -363,6 +359,12 @@ def __init__(
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.sequence_parallel_size = sequence_parallel_size
+        self.dtype = dtype
+        if enable_flashattn:
+            assert dtype in [
+                torch.float16,
+                torch.bfloat16,
+            ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"

         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
         self.t_embedder = TimestepEmbedder(hidden_size)
@@ -470,6 +472,10 @@ def forward(self, x, t, y):

         # Todo: Mock video input by repeating the same frame for all timesteps
         # x = torch.randn(2, 256, 1152).to(torch.bfloat16).cuda()
+
+        # original inputs should be float32; cast to the specified dtype
+        x = x.to(self.dtype)
+
         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
         t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         y = self.y_embedder(y, self.training)  # (N, D)
@@ -490,6 +496,9 @@ def forward(self, x, t, y):

         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x)  # (N, out_channels, H, W)
+
+        # cast to float32 for better accuracy
+        x = x.to(torch.float32)
         return x

     def forward_with_cfg(self, x, t, y, cfg_scale):
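Together with the new `dtype` argument in `__init__`, these casts give `forward` a cast-in / cast-out pattern: flash-attn kernels accept only float16/bfloat16, so the float32 inputs are cast to the model dtype on entry and the output is cast back to float32 at the end. A minimal sketch of the pattern, using a hypothetical `run_dit` wrapper and `model` argument that are not part of this PR:

```python
import torch

def run_dit(model, x, t, y, dtype=torch.bfloat16):
    """Hypothetical wrapper illustrating the cast-in / cast-out pattern."""
    x = x.to(dtype)               # inputs arrive as float32; cast to fp16/bf16 for flash-attn
    out = model(x, t, y)          # the transformer runs in the reduced-precision dtype
    return out.to(torch.float32)  # cast back to float32 for better downstream accuracy
```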