 import torch
 import torch.nn as nn

-from fastvideo.v1.attention import DistributedAttention, LocalAttention
-from fastvideo.v1.configs.models.dits.cosmos import CosmosConfig
-from fastvideo.v1.forward_context import get_forward_context
-from fastvideo.v1.layers.layernorm import RMSNorm
-from fastvideo.v1.layers.linear import ReplicatedLinear
-from fastvideo.v1.layers.mlp import MLP
-from fastvideo.v1.layers.rotary_embedding import apply_rotary_emb
-from fastvideo.v1.layers.visual_embedding import Timesteps
-from fastvideo.v1.models.dits.base import BaseDiT
-from fastvideo.v1.platforms import AttentionBackendEnum
+from fastvideo.attention import DistributedAttention, LocalAttention
+from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
+from fastvideo.forward_context import get_forward_context
+from fastvideo.layers.layernorm import RMSNorm
+from fastvideo.layers.linear import ReplicatedLinear
+from fastvideo.layers.mlp import MLP
+from fastvideo.layers.rotary_embedding import apply_rotary_emb
+from fastvideo.layers.visual_embedding import Timesteps
+from fastvideo.models.dits.base import BaseDiT
+from fastvideo.platforms import AttentionBackendEnum


 class CosmosPatchEmbed(nn.Module):
@@ -170,34 +170,49 @@ def __init__(self,
                  eps=1e-6,
                  supported_attention_backends: tuple[AttentionBackendEnum, ...]
                  | None = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 attention_backend: str = "distributed") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-
-        # layers
-        self.to_q = ReplicatedLinear(dim, dim, bias=False)
-        self.to_k = ReplicatedLinear(dim, dim, bias=False)
-        self.to_v = ReplicatedLinear(dim, dim, bias=False)
-        self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.attention_backend = attention_backend
+
+        # layers - use standard PyTorch layers when using torch backend
+        if attention_backend == "torch":
+            self.to_q = nn.Linear(dim, dim, bias=False)
+            self.to_k = nn.Linear(dim, dim, bias=False)
+            self.to_v = nn.Linear(dim, dim, bias=False)
+            self.to_out = nn.Linear(dim, dim, bias=False)
+        else:
+            self.to_q = ReplicatedLinear(dim, dim, bias=False)
+            self.to_k = ReplicatedLinear(dim, dim, bias=False)
+            self.to_v = ReplicatedLinear(dim, dim, bias=False)
+            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()

-        # Attention mechanism
-        self.attn = DistributedAttention(
-            num_heads=num_heads,
-            head_size=self.head_dim,
-            dropout_rate=0,
-            softmax_scale=None,
-            causal=False,
-            supported_attention_backends=supported_attention_backends,
-            prefix=prefix)
+        # Attention mechanism - select backend
+        if attention_backend == "torch":
+            self.use_torch_attention = True
+        elif attention_backend == "distributed":
+            self.attn = DistributedAttention(
+                num_heads=num_heads,
+                head_size=self.head_dim,
+                dropout_rate=0,
+                softmax_scale=None,
+                causal=False,
+                supported_attention_backends=supported_attention_backends,
+                prefix=prefix)
+            self.use_torch_attention = False
+        else:
+            raise ValueError(f"Unsupported attention backend: {attention_backend}")

     def forward(self,
                 hidden_states: torch.Tensor,
@@ -209,9 +224,14 @@ def forward(self,
             encoder_hidden_states = hidden_states

         # Get QKV
-        query, _ = self.to_q(hidden_states)
-        key, _ = self.to_k(encoder_hidden_states)
-        value, _ = self.to_v(encoder_hidden_states)
+        if self.attention_backend == "torch":
+            query = self.to_q(hidden_states)
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+        else:
+            query, _ = self.to_q(hidden_states)
+            key, _ = self.to_k(encoder_hidden_states)
+            value, _ = self.to_v(encoder_hidden_states)

         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -236,12 +256,21 @@ def forward(self,
                                    use_real_unbind_dim=-2)

         # Attention computation
-        attn_output, _ = self.attn(query, key, value)
-        # attn_output = attn_output.flatten(2)
+        if self.use_torch_attention:
+            # Use standard PyTorch scaled dot product attention
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            )
+        else:
+            attn_output, _ = self.attn(query, key, value)
+
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)

         # Output projection
-        attn_output, _ = self.to_out(attn_output)
+        if self.attention_backend == "torch":
+            attn_output = self.to_out(attn_output)
+        else:
+            attn_output, _ = self.to_out(attn_output)
         return attn_output


@@ -255,7 +284,8 @@ def __init__(self,
                  eps=1e-6,
                  supported_attention_backends: tuple[AttentionBackendEnum, ...]
                  | None = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 attention_backend: str = "distributed") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -264,40 +294,65 @@ def __init__(self,
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-
-        # layers
-        self.to_q = ReplicatedLinear(dim, dim, bias=False)
-        self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-        self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-        self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.attention_backend = attention_backend
+
+        # layers - use standard PyTorch layers when using torch backend
+        if attention_backend == "torch":
+            self.to_q = nn.Linear(dim, dim, bias=False)
+            self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
+            self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
+            self.to_out = nn.Linear(dim, dim, bias=False)
+        else:
+            self.to_q = ReplicatedLinear(dim, dim, bias=False)
+            self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
+            self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
+            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()

-        # Attention mechanism
-        self.attn = LocalAttention(
-            num_heads=num_heads,
-            head_size=self.head_dim,
-            dropout_rate=0,
-            softmax_scale=None,
-            causal=False,
-            supported_attention_backends=supported_attention_backends)
+        # Attention mechanism - select backend
+        if attention_backend == "torch":
+            self.use_torch_attention = True
+        elif attention_backend == "distributed":
+            self.attn = LocalAttention(
+                num_heads=num_heads,
+                head_size=self.head_dim,
+                dropout_rate=0,
+                softmax_scale=None,
+                causal=False,
+                supported_attention_backends=supported_attention_backends)
+            self.use_torch_attention = False
+        else:
+            raise ValueError(f"Unsupported attention backend: {attention_backend}")

     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor | None = None) -> torch.Tensor:

         # Get QKV
-        query, _ = self.to_q(hidden_states)
-        key, _ = self.to_k(encoder_hidden_states)
-        value, _ = self.to_v(encoder_hidden_states)
+        if self.attention_backend == "torch":
+            query = self.to_q(hidden_states)
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+        else:
+            query, _ = self.to_q(hidden_states)
+            key, _ = self.to_k(encoder_hidden_states)
+            value, _ = self.to_v(encoder_hidden_states)

         # Reshape for multi-head attention
-        query = query.unflatten(2, (self.num_heads, -1))
-        key = key.unflatten(2, (self.num_heads, -1))
-        value = value.unflatten(2, (self.num_heads, -1))
+        if self.use_torch_attention:
+            # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
+            query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+            key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+            value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        else:
+            query = query.unflatten(2, (self.num_heads, -1))
+            key = key.unflatten(2, (self.num_heads, -1))
+            value = value.unflatten(2, (self.num_heads, -1))

         # Apply normalization
         if self.norm_q is not None:
@@ -306,11 +361,20 @@ def forward(self,
             key = self.norm_k.forward_native(key)

         # Attention computation
-        attn_output = self.attn(query, key, value)
-        attn_output = attn_output.flatten(2, 3).type_as(query)
+        if self.use_torch_attention:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            )
+            attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
+        else:
+            attn_output = self.attn(query, key, value)
+            attn_output = attn_output.flatten(2, 3).type_as(query)

         # Output projection
-        attn_output, _ = self.to_out(attn_output)
+        if self.attention_backend == "torch":
+            attn_output = self.to_out(attn_output)
+        else:
+            attn_output, _ = self.to_out(attn_output)
         return attn_output


@@ -328,6 +392,7 @@ def __init__(
         supported_attention_backends: tuple[AttentionBackendEnum, ...]
         | None = None,
         prefix: str = "",
+        attention_backend: str = "distributed",
     ) -> None:
         super().__init__()

@@ -340,7 +405,8 @@ def __init__(
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
             supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn1")
+            prefix=f"{prefix}.attn1",
+            attention_backend=attention_backend)

         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -350,7 +416,8 @@ def __init__(
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
             supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn2")
+            prefix=f"{prefix}.attn2",
+            attention_backend=attention_backend)

         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -529,13 +596,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 class CosmosTransformer3DModel(BaseDiT):
-    _fsdp_shard_conditions = CosmosConfig()._fsdp_shard_conditions
-    _compile_conditions = CosmosConfig()._compile_conditions
-    _supported_attention_backends = CosmosConfig()._supported_attention_backends
-    _param_names_mapping = CosmosConfig()._param_names_mapping
-    _lora_param_names_mapping = CosmosConfig()._lora_param_names_mapping
+    _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions
+    _compile_conditions = CosmosVideoConfig()._compile_conditions
+    _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
+    param_names_mapping = CosmosVideoConfig().param_names_mapping
+    lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping

-    def __init__(self, config: CosmosConfig, hf_config: dict[str, Any]) -> None:
+    def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None:
         super().__init__(config=config, hf_config=hf_config)

         inner_dim = config.num_attention_heads * config.attention_head_dim
@@ -586,6 +653,7 @@ def __init__(self, config: CosmosConfig, hf_config: dict[str, Any]) -> None:
                 out_bias=False,
                 supported_attention_backends=self._supported_attention_backends,
                 prefix=f"{config.prefix}.transformer_blocks.{i}",
+                attention_backend=config.arch_config.attention_backend,
             ) for i in range(config.num_layers)
         ])

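For context on the new `attention_backend="torch"` branch, here is a minimal, dependency-free sketch of what that code path does: plain `nn.Linear` projections plus `torch.nn.functional.scaled_dot_product_attention`, with the same unflatten/transpose/flatten shape handling as the diff. The class name `TorchBackendSelfAttention` and the `__main__` shape check are illustrative only and not part of fastvideo; the real modules above also apply QK-norm and rotary embeddings, and fall back to the `ReplicatedLinear`/`DistributedAttention` path when `attention_backend="distributed"`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TorchBackendSelfAttention(nn.Module):
    """Hypothetical stand-in for the attention_backend == "torch" branch:
    plain nn.Linear projections and PyTorch SDPA, no QK-norm or rotary
    embedding."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # nn.Linear returns a plain tensor, which is why the torch branch in
        # the diff drops the tuple unpacking used for ReplicatedLinear.
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        # [batch, seq, dim] -> [batch, num_heads, seq, head_dim]
        query = self.to_q(hidden_states).unflatten(2, (self.num_heads, -1)).transpose(1, 2)
        key = self.to_k(hidden_states).unflatten(2, (self.num_heads, -1)).transpose(1, 2)
        value = self.to_v(hidden_states).unflatten(2, (self.num_heads, -1)).transpose(1, 2)

        # SDPA expects [batch, num_heads, seq, head_dim]
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0)

        # [batch, num_heads, seq, head_dim] -> [batch, seq, dim]
        attn_output = attn_output.transpose(1, 2).flatten(2, 3)
        return self.to_out(attn_output)


if __name__ == "__main__":
    attn = TorchBackendSelfAttention(dim=128, num_heads=8)
    out = attn(torch.randn(2, 16, 128))
    print(out.shape)  # torch.Size([2, 16, 128])
```

The cross-attention variant in the diff differs only in that key/value are projected from `encoder_hidden_states` (with `cross_attention_dim` input features) and that the non-torch path keeps heads unsplit for `LocalAttention`.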