@@ -168,52 +168,26 @@ def __init__(self,
                  num_heads: int,
                  qk_norm=True,
                  eps=1e-6,
-                 supported_attention_backends: tuple[AttentionBackendEnum, ...]
-                 | None = None,
-                 prefix: str = "",
-                 attention_backend: str = "distributed") -> None:
+                 prefix: str = "") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-        self.attention_backend = attention_backend

         # layers - use standard PyTorch layers when using torch backend
-        if attention_backend == "torch":
-            self.to_q = nn.Linear(dim, dim, bias=False)
-            self.to_k = nn.Linear(dim, dim, bias=False)
-            self.to_v = nn.Linear(dim, dim, bias=False)
-            self.to_out = nn.Linear(dim, dim, bias=False)
-        else:
-            self.to_q = ReplicatedLinear(dim, dim, bias=False)
-            self.to_k = ReplicatedLinear(dim, dim, bias=False)
-            self.to_v = ReplicatedLinear(dim, dim, bias=False)
-            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(dim, dim, bias=False)
+        self.to_v = nn.Linear(dim, dim, bias=False)
+        self.to_out = nn.Linear(dim, dim, bias=False)

         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()

-        # Attention mechanism - select backend
-        if attention_backend == "torch":
-            self.use_torch_attention = True
-        elif attention_backend == "distributed":
-            self.attn = DistributedAttention(
-                num_heads=num_heads,
-                head_size=self.head_dim,
-                dropout_rate=0,
-                softmax_scale=None,
-                causal=False,
-                supported_attention_backends=supported_attention_backends,
-                prefix=prefix)
-            self.use_torch_attention = False
-        else:
-            raise ValueError(f"Unsupported attention backend: {attention_backend}")
-
     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor | None = None,
@@ -224,14 +198,9 @@ def forward(self,
             encoder_hidden_states = hidden_states

         # Get QKV
-        if self.attention_backend == "torch":
-            query = self.to_q(hidden_states)
-            key = self.to_k(encoder_hidden_states)
-            value = self.to_v(encoder_hidden_states)
-        else:
-            query, _ = self.to_q(hidden_states)
-            key, _ = self.to_k(encoder_hidden_states)
-            value, _ = self.to_v(encoder_hidden_states)
+        query = self.to_q(hidden_states)
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)

         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -256,21 +225,15 @@ def forward(self,
                                  use_real_unbind_dim=-2)

         # Attention computation
-        if self.use_torch_attention:
-            # Use standard PyTorch scaled dot product attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0
-            )
-        else:
-            attn_output, _ = self.attn(query, key, value)
-
+        # Use standard PyTorch scaled dot product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+        )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)

         # Output projection
-        if self.attention_backend == "torch":
-            attn_output = self.to_out(attn_output)
-        else:
-            attn_output, _ = self.to_out(attn_output)
+        attn_output = self.to_out(attn_output)
+
         return attn_output


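The self-attention path now reduces to plain PyTorch calls. Below is a minimal standalone sketch of that path under illustrative sizes; the q/k RMSNorm and rotary-embedding steps are omitted, and the tensor names only mirror the code above rather than being taken from the repository.

# Minimal sketch of the pure-PyTorch self-attention path (illustrative sizes).
import torch
import torch.nn.functional as F
from torch import nn

batch, seq_len, dim, num_heads = 2, 16, 64, 4
head_dim = dim // num_heads

to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)
to_out = nn.Linear(dim, dim, bias=False)

hidden_states = torch.randn(batch, seq_len, dim)

# Project, then reshape to [batch, num_heads, seq_len, head_dim]
query = to_q(hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)
key = to_k(hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)
value = to_v(hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)

attn = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)

# Back to [batch, seq_len, dim], then output projection
out = to_out(attn.transpose(1, 2).flatten(2, 3))
assert out.shape == (batch, seq_len, dim)
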
@@ -282,10 +245,7 @@ def __init__(self,
                  num_heads: int,
                  qk_norm=True,
                  eps=1e-6,
-                 supported_attention_backends: tuple[AttentionBackendEnum, ...]
-                 | None = None,
-                 prefix: str = "",
-                 attention_backend: str = "distributed") -> None:
+                 prefix: str = "") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -294,65 +254,33 @@ def __init__(self,
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-        self.attention_backend = attention_backend

         # layers - use standard PyTorch layers when using torch backend
-        if attention_backend == "torch":
-            self.to_q = nn.Linear(dim, dim, bias=False)
-            self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
-            self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
-            self.to_out = nn.Linear(dim, dim, bias=False)
-        else:
-            self.to_q = ReplicatedLinear(dim, dim, bias=False)
-            self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-            self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
+        self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
+        self.to_out = nn.Linear(dim, dim, bias=False)

         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()

-        # Attention mechanism - select backend
-        if attention_backend == "torch":
-            self.use_torch_attention = True
-        elif attention_backend == "distributed":
-            self.attn = LocalAttention(
-                num_heads=num_heads,
-                head_size=self.head_dim,
-                dropout_rate=0,
-                softmax_scale=None,
-                causal=False,
-                supported_attention_backends=supported_attention_backends)
-            self.use_torch_attention = False
-        else:
-            raise ValueError(f"Unsupported attention backend: {attention_backend}")
-
     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor | None = None) -> torch.Tensor:

         # Get QKV
-        if self.attention_backend == "torch":
-            query = self.to_q(hidden_states)
-            key = self.to_k(encoder_hidden_states)
-            value = self.to_v(encoder_hidden_states)
-        else:
-            query, _ = self.to_q(hidden_states)
-            key, _ = self.to_k(encoder_hidden_states)
-            value, _ = self.to_v(encoder_hidden_states)
+        query = self.to_q(hidden_states)
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)

         # Reshape for multi-head attention
-        if self.use_torch_attention:
         # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
-            query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-            key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-            value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-        else:
-            query = query.unflatten(2, (self.num_heads, -1))
-            key = key.unflatten(2, (self.num_heads, -1))
-            value = value.unflatten(2, (self.num_heads, -1))
+        query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)

         # Apply normalization
         if self.norm_q is not None:
@@ -361,20 +289,14 @@ def forward(self,
             key = self.norm_k.forward_native(key)

         # Attention computation
-        if self.use_torch_attention:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0
-            )
-            attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
-        else:
-            attn_output = self.attn(query, key, value)
-            attn_output = attn_output.flatten(2, 3).type_as(query)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+        )
+        attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)

         # Output projection
-        if self.attention_backend == "torch":
-            attn_output = self.to_out(attn_output)
-        else:
-            attn_output, _ = self.to_out(attn_output)
+        attn_output = self.to_out(attn_output)
+
         return attn_output


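The cross-attention path follows the same pattern, except that keys and values are projected from encoder states, which may have a different feature width (cross_attention_dim) and sequence length than the query stream. The sketch below exercises that shape handling in isolation; the sizes and the boolean attention mask are illustrative assumptions, not values from the model config.

# Illustrative sketch of the cross-attention shape handling (hypothetical sizes).
import torch
import torch.nn.functional as F
from torch import nn

batch, q_len, kv_len = 2, 16, 24
dim, cross_attention_dim, num_heads = 64, 96, 4

to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(cross_attention_dim, dim, bias=False)
to_v = nn.Linear(cross_attention_dim, dim, bias=False)
to_out = nn.Linear(dim, dim, bias=False)

hidden_states = torch.randn(batch, q_len, dim)
encoder_hidden_states = torch.randn(batch, kv_len, cross_attention_dim)

query = to_q(hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)
key = to_k(encoder_hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)
value = to_v(encoder_hidden_states).unflatten(2, (num_heads, -1)).transpose(1, 2)

# SDPA accepts a boolean mask broadcastable to [batch, num_heads, q_len, kv_len];
# True marks key positions the query is allowed to attend to.
attention_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)

attn = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0)
out = to_out(attn.transpose(1, 2).flatten(2, 3))
assert out.shape == (batch, q_len, dim)
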
@@ -389,10 +311,7 @@ def __init__(
         adaln_lora_dim: int = 256,
         qk_norm: str = "rms_norm",
         out_bias: bool = False,
-        supported_attention_backends: tuple[AttentionBackendEnum, ...]
-        | None = None,
         prefix: str = "",
-        attention_backend: str = "distributed",
     ) -> None:
         super().__init__()

@@ -404,9 +323,7 @@ def __init__(
             dim=hidden_size,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
-            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn1",
-            attention_backend=attention_backend)
+            prefix=f"{prefix}.attn1")

         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -415,9 +332,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
-            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn2",
-            attention_backend=attention_backend)
+            prefix=f"{prefix}.attn2")

         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -598,7 +513,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class CosmosTransformer3DModel(BaseDiT):
     _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions
     _compile_conditions = CosmosVideoConfig()._compile_conditions
-    _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
+    # _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
     param_names_mapping = CosmosVideoConfig().param_names_mapping
     lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping

@@ -651,9 +566,9 @@ def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None
                 adaln_lora_dim=config.adaln_lora_dim,
                 qk_norm=config.qk_norm,
                 out_bias=False,
-                supported_attention_backends=self._supported_attention_backends,
+                # supported_attention_backends=self._supported_attention_backends,
                 prefix=f"{config.prefix}.transformer_blocks.{i}",
-                attention_backend=config.arch_config.attention_backend,
+                # attention_backend=config.arch_config.attention_backend,
             ) for i in range(config.num_layers)
         ])
