@@ -23,12 +23,6 @@
 from opendit.models.clip import TextEmbedder
 from opendit.utils.operation import all_to_all_comm, gather_forward_split_backward
 
-ULYSSES = False
-FLASH_ATTN = False
-SP_SIZE = 2
-LAYERNORM_KERNEL = False
-MODULATE_KERNEL = False
-
 
 def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
     if use_kernel:
@@ -45,16 +39,17 @@ def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel):
 def modulate(norm_func, x, shift, scale, use_kernel=False):
     # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
     dtype = x.dtype
-    x = norm_func(x.to(torch.float32))
+    x, shift, scale = x.to(torch.float32), shift.to(torch.float32), scale.to(torch.float32)
+    x = norm_func(x)
     if use_kernel:
         try:
             from opendit.kernels.fused_modulate import fused_modulate
 
-            x = fused_modulate(x, scale.to(torch.float32), shift.to(torch.float32))
+            x = fused_modulate(x, scale, shift)
         except ImportError:
             raise RuntimeError("FusedModulate kernel not available. Please install triton.")
     else:
-        x = x * (scale.to(torch.float32).unsqueeze(1) + 1) + shift.to(torch.float32).unsqueeze(1)
+        x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
     x = x.to(dtype)
 
     return x
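
The rewritten `modulate` above casts `x`, `shift`, and `scale` to fp32 once, normalizes, applies the per-sample affine modulation, and casts back to the input dtype. A minimal sketch of the non-kernel branch follows; the shapes (N samples, T tokens, D channels) are made-up values for illustration only.

import torch
import torch.nn as nn

# Illustrative shapes only: N=2 samples, T=4 tokens, D=8 channels.
N, T, D = 2, 4, 8
x = torch.randn(N, T, D, dtype=torch.float16)
shift, scale = torch.randn(N, D), torch.randn(N, D)
norm = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)

# Same arithmetic as the non-kernel branch: normalize in fp32, apply the
# per-sample shift/scale broadcast over the token axis, then cast back.
out = norm(x.float()) * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
out = out.to(x.dtype)
print(out.shape, out.dtype)  # torch.Size([2, 4, 8]) torch.float16
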
@@ -156,8 +151,8 @@ def __init__(
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         norm_layer: nn.Module = nn.LayerNorm,
-        use_flash_attn: bool = False,
-        enable_sequence_parallelism: bool = False,
+        enable_flashattn: bool = False,
+        sequence_parallel_size: int = 1,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, "dim should be divisible by num_heads"
@@ -172,16 +167,20 @@ def __init__(
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
-        self.use_flash_attn = use_flash_attn
-        self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.enable_flashattn = enable_flashattn
+        # TODO: support sequence_parallel_size > 2
+        assert sequence_parallel_size in [1, 2], "sequence_parallel_size is only supported for 1 or 2"
+        self.sequence_parallel_size = sequence_parallel_size
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)  # (B, N, C), N here is N_total // SP_SIZE
         # Todo: Change num_heads in somewhere else for a better code style
-        num_heads = self.num_heads if not self.enable_sequence_parallelism else self.num_heads // SP_SIZE
+        num_heads = (
+            self.num_heads if self.sequence_parallel_size == 1 else self.num_heads // self.sequence_parallel_size
+        )
 
-        if self.enable_sequence_parallelism:
+        if self.sequence_parallel_size > 1:
             q, k, v = qkv.split(self.head_dim * self.num_heads, dim=-1)
             # q = q.reshape(1, -1, self.head_dim * self.num_heads)
             # k = k.reshape(1, -1, self.head_dim * self.num_heads)
@@ -191,9 +190,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             k = all_to_all_comm(k, None)
             v = all_to_all_comm(v, None)
 
-            q = q.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-            k = k.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-            v = v.reshape(B, N * SP_SIZE, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            q = q.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            k = k.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+            v = v.reshape(B, N * self.sequence_parallel_size, num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
 
         else:
             # Todo: chunked flash attention
@@ -204,7 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             #     .permute(2, 3, 0, 1, 4)
             #     .reshape(3, B * num_heads, 1, N, self.head_dim)
             # )
-            if self.use_flash_attn:
+            if self.enable_flashattn:
                 # [3, B, num_heads, N, head_dim] => [B, N, num_heads, head_dim] * 3
                 qkv = qkv.reshape(B, N, 3, num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
             else:
@@ -213,7 +212,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)
 
-        if self.use_flash_attn:
+        if self.enable_flashattn:
             from flash_attn import flash_attn_func
 
             # Todo: chunked flash attention
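
The flash-attention path imports `flash_attn_func` lazily; the call itself falls outside this hunk. As a hedged sketch (not the repo's exact call), `flash_attn_func` takes `(B, N, num_heads, head_dim)` fp16/bf16 CUDA tensors and defaults the softmax scale to `1/sqrt(head_dim)`.

from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

def flash_attention(q, k, v, dropout_p=0.0):
    # q, k, v: (B, N, num_heads, head_dim) in fp16/bf16; returns the same layout.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=False)
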
@@ -258,10 +257,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = attn @ v
 
         x_output_shape = (
-            (B, N, C) if not self.enable_sequence_parallelism else (B, N * SP_SIZE, num_heads * self.head_dim)
+            (B, N, C)
+            if self.sequence_parallel_size == 1
+            else (B, N * self.sequence_parallel_size, num_heads * self.head_dim)
         )
         x = x.transpose(1, 2).reshape(x_output_shape)
-        if self.enable_sequence_parallelism:
+        if self.sequence_parallel_size > 1:
             # Todo: Use all_to_all_single for x
             # x = x.reshape(1, -1, num_heads * self.head_dim)
             x = all_to_all_comm(x, None, scatter_dim=1, gather_dim=2)
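
In the sequence-parallel branch, each rank enters with its local `N = N_total / sequence_parallel_size` tokens and all heads; `all_to_all_comm` gathers the full sequence while scattering heads, so attention runs over every token with `num_heads / sequence_parallel_size` heads per rank, and the final `all_to_all_comm(x, None, scatter_dim=1, gather_dim=2)` undoes the exchange. Below is a hedged sketch of that head/sequence exchange written against `torch.distributed` directly (OpenDiT's `all_to_all_comm` wrapper and its signature are not reproduced here), assuming the default process group spans exactly `sp_size` ranks.

import torch
import torch.distributed as dist

def seq_to_head_parallel(t: torch.Tensor, sp_size: int) -> torch.Tensor:
    """(B, N_local, H, D) -> (B, N_local * sp_size, H // sp_size, D)."""
    B, N, H, D = t.shape
    # Group the heads into sp_size buckets; bucket i is sent to rank i.
    t = t.reshape(B, N, sp_size, H // sp_size, D).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(t)
    dist.all_to_all_single(out, t)
    # out[j] now holds rank j's sequence slice for this rank's head bucket;
    # stack the slices back along the sequence dimension.
    return out.permute(1, 0, 2, 3, 4).reshape(B, N * sp_size, H // sp_size, D)
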
@@ -281,33 +282,37 @@ def __init__(
         hidden_size,
         num_heads,
         mlp_ratio=4.0,
-        flash_attn=False,
-        sequence_parallel=False,
-        layernorm_kernel=False,
-        modulate_kernel=False,
+        enable_flashattn=False,
+        sequence_parallel_size=False,
+        enable_layernorm_kernel=False,
+        enable_modulate_kernel=False,
         **block_kwargs,
     ):
         super().__init__()
-        self.modulate_kernel = modulate_kernel
-        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=layernorm_kernel)
+        self.enable_modulate_kernel = enable_modulate_kernel
+        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
         self.attn = DistAttention(
             hidden_size,
             num_heads=num_heads,
             qkv_bias=True,
-            use_flash_attn=flash_attn,
-            enable_sequence_parallelism=sequence_parallel,
+            enable_flashattn=enable_flashattn,
+            sequence_parallel_size=sequence_parallel_size,
             **block_kwargs,
         )
-        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=layernorm_kernel)
+        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         approx_gelu = lambda: nn.GELU(approximate="tanh")
         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
 
     def forward(self, x, c):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
-        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa, self.modulate_kernel))
-        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp, self.modulate_kernel))
+        x = x + gate_msa.unsqueeze(1) * self.attn(
+            modulate(self.norm1, x, shift_msa, scale_msa, self.enable_modulate_kernel)
+        )
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2, x, shift_mlp, scale_mlp, self.enable_modulate_kernel)
+        )
         return x
 
 
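
For reference, the six modulation tensors unpacked in `DiTBlock.forward` above each have shape `(N, hidden_size)`; a quick shape check (the hidden size and batch size here are arbitrary example values):

import torch
import torch.nn as nn

hidden_size, batch = 64, 2  # arbitrary example values
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
c = torch.randn(batch, hidden_size)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_modulation(c).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([2, 64]) -- one (N, hidden_size) tensor per term
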
@@ -347,17 +352,18 @@ def __init__(
         class_dropout_prob=0.1,
         num_classes=1000,
         learn_sigma=True,
-        flash_attn=FLASH_ATTN,
-        sequence_parallel=ULYSSES,
-        layernorm_kernel=LAYERNORM_KERNEL,
-        modulate_kernel=MODULATE_KERNEL,
+        enable_flashattn=False,
+        enable_layernorm_kernel=False,
+        enable_modulate_kernel=False,
+        sequence_parallel_size=1,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
         self.in_channels = in_channels
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
+        self.sequence_parallel_size = sequence_parallel_size
 
         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
         self.t_embedder = TimestepEmbedder(hidden_size)
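
With the module-level globals gone, the features are selected per model instance through the renamed keyword arguments. A hedged usage sketch: the module path `opendit.models.dit` and the DiT-XL/2-style size values are assumptions, not part of this diff, and `sequence_parallel_size` must currently be 1 or 2 per the assert in `DistAttention`.

from opendit.models.dit import DiT  # assumed module path

model = DiT(
    input_size=32,
    patch_size=2,
    in_channels=4,
    hidden_size=1152,
    depth=28,
    num_heads=16,
    enable_flashattn=True,          # was the FLASH_ATTN global
    enable_layernorm_kernel=True,   # was LAYERNORM_KERNEL
    enable_modulate_kernel=True,    # was MODULATE_KERNEL
    sequence_parallel_size=2,       # was ULYSSES / SP_SIZE; 1 disables sequence parallelism
)
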
@@ -378,10 +384,10 @@ def __init__(
                 hidden_size,
                 num_heads,
                 mlp_ratio=mlp_ratio,
-                flash_attn=flash_attn,
-                sequence_parallel=sequence_parallel,
-                modulate_kernel=modulate_kernel,
-                layernorm_kernel=layernorm_kernel,
+                enable_flashattn=enable_flashattn,
+                sequence_parallel_size=sequence_parallel_size,
+                enable_modulate_kernel=enable_modulate_kernel,
+                enable_layernorm_kernel=enable_layernorm_kernel,
             )
             for _ in range(depth)
         ]
@@ -471,16 +477,16 @@ def forward(self, x, t, y):
         c = t + y  # (N, D)
 
         # Chunk x on sequence dimension to sp group
-        if ULYSSES:
-            x = x.chunk(SP_SIZE, dim=1)[dist.get_rank()]
+        if self.sequence_parallel_size > 1:
+            x = x.chunk(self.sequence_parallel_size, dim=1)[dist.get_rank()]
 
         for block in self.blocks:
             if self.gradient_checkpointing:
                 x = torch.utils.checkpoint.checkpoint(self.create_custom_forward(block), x, c)
             else:
                 x = block(x, c)  # (N, T, D)
 
-        if ULYSSES:
+        if self.sequence_parallel_size > 1:
             x = gather_forward_split_backward(x, dim=1, process_group=None)
 
         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
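
The forward pass above shards the token dimension before the blocks and reassembles it afterwards, so each rank runs the transformer on `1/sequence_parallel_size` of the sequence. A forward-only sketch of that round trip using plain collectives (OpenDiT's `gather_forward_split_backward` additionally splits gradients in the backward pass), assuming `torch.distributed` is initialized and the default group has `sp_size` ranks:

import torch
import torch.distributed as dist

def shard_sequence(x: torch.Tensor, sp_size: int) -> torch.Tensor:
    # Keep only this rank's slice of the token dimension, as in DiT.forward.
    return x.chunk(sp_size, dim=1)[dist.get_rank()]

def gather_sequence(x: torch.Tensor, sp_size: int) -> torch.Tensor:
    # Reassemble the full sequence on every rank (no custom backward here).
    pieces = [torch.empty_like(x) for _ in range(sp_size)]
    dist.all_gather(pieces, x.contiguous())
    return torch.cat(pieces, dim=1)
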