
Commit 1a96ad9

Fix dit test
1 parent d763cc5 commit 1a96ad9

5 files changed (+482, -67 lines)


fastvideo/configs/models/dits/cosmos.py

Lines changed: 5 additions & 2 deletions
@@ -13,7 +13,7 @@ class CosmosArchConfig(DiTArchConfig):
     _fsdp_shard_conditions: list = field(
         default_factory=lambda: [is_transformer_blocks])
 
-    _param_names_mapping: dict = field(
+    param_names_mapping: dict = field(
         default_factory=lambda: {
            r"^patch_embed\.(.*)$": r"patch_embed.\1",
            r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
@@ -51,7 +51,7 @@ class CosmosArchConfig(DiTArchConfig):
            r"^proj_out\.(.*)$": r"proj_out.\1",
        })
 
-    _lora_param_names_mapping: dict = field(
+    lora_param_names_mapping: dict = field(
         default_factory=lambda: {
            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
                r"transformer_blocks.\1.attn1.to_q.\2",
@@ -90,6 +90,9 @@ class CosmosArchConfig(DiTArchConfig):
     qk_norm: str = "rms_norm"
     eps: float = 1e-6
     exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
+
+    # Attention backend selection
+    attention_backend: str = "distributed"  # Options: "distributed", "torch"
 
     def __post_init__(self):
         super().__post_init__()

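Note: the new attention_backend field lives on the arch config, so switching attention implementations is a config-only change. A minimal usage sketch, assuming CosmosVideoConfig is constructible with defaults, exposes an arch_config attribute (as the model code below implies), and allows mutating it after construction:

# Hedged sketch -- the no-arg constructor and post-construction mutation are
# assumptions inferred from this diff, not guaranteed by the library.
from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig

config = CosmosVideoConfig()
config.arch_config.attention_backend = "torch"  # default is "distributed"

# CosmosTransformer3DModel then builds plain nn.Linear projections and runs
# torch.nn.functional.scaled_dot_product_attention instead of
# ReplicatedLinear + DistributedAttention/LocalAttention.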
fastvideo/models/dits/cosmos.py

Lines changed: 132 additions & 64 deletions
@@ -7,16 +7,16 @@
 import torch
 import torch.nn as nn
 
-from fastvideo.v1.attention import DistributedAttention, LocalAttention
-from fastvideo.v1.configs.models.dits.cosmos import CosmosConfig
-from fastvideo.v1.forward_context import get_forward_context
-from fastvideo.v1.layers.layernorm import RMSNorm
-from fastvideo.v1.layers.linear import ReplicatedLinear
-from fastvideo.v1.layers.mlp import MLP
-from fastvideo.v1.layers.rotary_embedding import apply_rotary_emb
-from fastvideo.v1.layers.visual_embedding import Timesteps
-from fastvideo.v1.models.dits.base import BaseDiT
-from fastvideo.v1.platforms import AttentionBackendEnum
+from fastvideo.attention import DistributedAttention, LocalAttention
+from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
+from fastvideo.forward_context import get_forward_context
+from fastvideo.layers.layernorm import RMSNorm
+from fastvideo.layers.linear import ReplicatedLinear
+from fastvideo.layers.mlp import MLP
+from fastvideo.layers.rotary_embedding import apply_rotary_emb
+from fastvideo.layers.visual_embedding import Timesteps
+from fastvideo.models.dits.base import BaseDiT
+from fastvideo.platforms import AttentionBackendEnum
 
 
 class CosmosPatchEmbed(nn.Module):
@@ -170,34 +170,49 @@ def __init__(self,
                  eps=1e-6,
                  supported_attention_backends: tuple[AttentionBackendEnum, ...]
                  | None = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 attention_backend: str = "distributed") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-
-        # layers
-        self.to_q = ReplicatedLinear(dim, dim, bias=False)
-        self.to_k = ReplicatedLinear(dim, dim, bias=False)
-        self.to_v = ReplicatedLinear(dim, dim, bias=False)
-        self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.attention_backend = attention_backend
+
+        # layers - use standard PyTorch layers when using torch backend
+        if attention_backend == "torch":
+            self.to_q = nn.Linear(dim, dim, bias=False)
+            self.to_k = nn.Linear(dim, dim, bias=False)
+            self.to_v = nn.Linear(dim, dim, bias=False)
+            self.to_out = nn.Linear(dim, dim, bias=False)
+        else:
+            self.to_q = ReplicatedLinear(dim, dim, bias=False)
+            self.to_k = ReplicatedLinear(dim, dim, bias=False)
+            self.to_v = ReplicatedLinear(dim, dim, bias=False)
+            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
 
-        # Attention mechanism
-        self.attn = DistributedAttention(
-            num_heads=num_heads,
-            head_size=self.head_dim,
-            dropout_rate=0,
-            softmax_scale=None,
-            causal=False,
-            supported_attention_backends=supported_attention_backends,
-            prefix=prefix)
+        # Attention mechanism - select backend
+        if attention_backend == "torch":
+            self.use_torch_attention = True
+        elif attention_backend == "distributed":
+            self.attn = DistributedAttention(
+                num_heads=num_heads,
+                head_size=self.head_dim,
+                dropout_rate=0,
+                softmax_scale=None,
+                causal=False,
+                supported_attention_backends=supported_attention_backends,
+                prefix=prefix)
+            self.use_torch_attention = False
+        else:
+            raise ValueError(f"Unsupported attention backend: {attention_backend}")
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -209,9 +224,14 @@ def forward(self,
            encoder_hidden_states = hidden_states
 
         # Get QKV
-        query, _ = self.to_q(hidden_states)
-        key, _ = self.to_k(encoder_hidden_states)
-        value, _ = self.to_v(encoder_hidden_states)
+        if self.attention_backend == "torch":
+            query = self.to_q(hidden_states)
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+        else:
+            query, _ = self.to_q(hidden_states)
+            key, _ = self.to_k(encoder_hidden_states)
+            value, _ = self.to_v(encoder_hidden_states)
 
         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -236,12 +256,21 @@
                                           use_real_unbind_dim=-2)
 
         # Attention computation
-        attn_output, _ = self.attn(query, key, value)
-        # attn_output = attn_output.flatten(2)
+        if self.use_torch_attention:
+            # Use standard PyTorch scaled dot product attention
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            )
+        else:
+            attn_output, _ = self.attn(query, key, value)
+
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
 
         # Output projection
-        attn_output, _ = self.to_out(attn_output)
+        if self.attention_backend == "torch":
+            attn_output = self.to_out(attn_output)
+        else:
+            attn_output, _ = self.to_out(attn_output)
         return attn_output
 
 
@@ -255,7 +284,8 @@ def __init__(self,
                  eps=1e-6,
                  supported_attention_backends: tuple[AttentionBackendEnum, ...]
                  | None = None,
-                 prefix: str = "") -> None:
+                 prefix: str = "",
+                 attention_backend: str = "distributed") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -264,40 +294,65 @@
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-
-        # layers
-        self.to_q = ReplicatedLinear(dim, dim, bias=False)
-        self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-        self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-        self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.attention_backend = attention_backend
+
+        # layers - use standard PyTorch layers when using torch backend
+        if attention_backend == "torch":
+            self.to_q = nn.Linear(dim, dim, bias=False)
+            self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
+            self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
+            self.to_out = nn.Linear(dim, dim, bias=False)
+        else:
+            self.to_q = ReplicatedLinear(dim, dim, bias=False)
+            self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
+            self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
+            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
 
-        # Attention mechanism
-        self.attn = LocalAttention(
-            num_heads=num_heads,
-            head_size=self.head_dim,
-            dropout_rate=0,
-            softmax_scale=None,
-            causal=False,
-            supported_attention_backends=supported_attention_backends)
+        # Attention mechanism - select backend
+        if attention_backend == "torch":
+            self.use_torch_attention = True
+        elif attention_backend == "distributed":
+            self.attn = LocalAttention(
+                num_heads=num_heads,
+                head_size=self.head_dim,
+                dropout_rate=0,
+                softmax_scale=None,
+                causal=False,
+                supported_attention_backends=supported_attention_backends)
+            self.use_torch_attention = False
+        else:
+            raise ValueError(f"Unsupported attention backend: {attention_backend}")
 
     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor | None = None) -> torch.Tensor:
 
         # Get QKV
-        query, _ = self.to_q(hidden_states)
-        key, _ = self.to_k(encoder_hidden_states)
-        value, _ = self.to_v(encoder_hidden_states)
+        if self.attention_backend == "torch":
+            query = self.to_q(hidden_states)
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+        else:
+            query, _ = self.to_q(hidden_states)
+            key, _ = self.to_k(encoder_hidden_states)
+            value, _ = self.to_v(encoder_hidden_states)
 
         # Reshape for multi-head attention
-        query = query.unflatten(2, (self.num_heads, -1))
-        key = key.unflatten(2, (self.num_heads, -1))
-        value = value.unflatten(2, (self.num_heads, -1))
+        if self.use_torch_attention:
+            # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
+            query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+            key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+            value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        else:
+            query = query.unflatten(2, (self.num_heads, -1))
+            key = key.unflatten(2, (self.num_heads, -1))
+            value = value.unflatten(2, (self.num_heads, -1))
 
         # Apply normalization
         if self.norm_q is not None:
@@ -306,11 +361,20 @@ def forward(self,
            key = self.norm_k.forward_native(key)
 
         # Attention computation
-        attn_output = self.attn(query, key, value)
-        attn_output = attn_output.flatten(2, 3).type_as(query)
+        if self.use_torch_attention:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            )
+            attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
+        else:
+            attn_output = self.attn(query, key, value)
+            attn_output = attn_output.flatten(2, 3).type_as(query)
 
         # Output projection
-        attn_output, _ = self.to_out(attn_output)
+        if self.attention_backend == "torch":
+            attn_output = self.to_out(attn_output)
+        else:
+            attn_output, _ = self.to_out(attn_output)
         return attn_output
 
 
@@ -328,6 +392,7 @@ def __init__(
         supported_attention_backends: tuple[AttentionBackendEnum, ...]
         | None = None,
         prefix: str = "",
+        attention_backend: str = "distributed",
     ) -> None:
         super().__init__()
 
@@ -340,7 +405,8 @@ def __init__(
            num_heads=num_attention_heads,
            qk_norm=(qk_norm == "rms_norm"),
            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn1")
+            prefix=f"{prefix}.attn1",
+            attention_backend=attention_backend)
 
         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -350,7 +416,8 @@ def __init__(
            num_heads=num_attention_heads,
            qk_norm=(qk_norm == "rms_norm"),
            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn2")
+            prefix=f"{prefix}.attn2",
+            attention_backend=attention_backend)
 
         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -529,13 +596,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class CosmosTransformer3DModel(BaseDiT):
-    _fsdp_shard_conditions = CosmosConfig()._fsdp_shard_conditions
-    _compile_conditions = CosmosConfig()._compile_conditions
-    _supported_attention_backends = CosmosConfig()._supported_attention_backends
-    _param_names_mapping = CosmosConfig()._param_names_mapping
-    _lora_param_names_mapping = CosmosConfig()._lora_param_names_mapping
+    _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions
+    _compile_conditions = CosmosVideoConfig()._compile_conditions
+    _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
+    param_names_mapping = CosmosVideoConfig().param_names_mapping
+    lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping
 
-    def __init__(self, config: CosmosConfig, hf_config: dict[str, Any]) -> None:
+    def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None:
         super().__init__(config=config, hf_config=hf_config)
 
         inner_dim = config.num_attention_heads * config.attention_head_dim
@@ -586,6 +653,7 @@ def __init__(self, config: CosmosConfig, hf_config: dict[str, Any]) -> None:
                out_bias=False,
                supported_attention_backends=self._supported_attention_backends,
                prefix=f"{config.prefix}.transformer_blocks.{i}",
+                attention_backend=config.arch_config.attention_backend,
            ) for i in range(config.num_layers)
         ])
 

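For reference, the "torch" path above reduces to plain linear projections plus torch.nn.functional.scaled_dot_product_attention, which expects tensors shaped [batch, num_heads, seq_len, head_dim]. A self-contained sketch of that reshaping; tensor names and sizes here are illustrative, not taken from the Cosmos model:

import torch
import torch.nn.functional as F

batch, seq_len, num_heads, head_dim = 2, 16, 4, 32
dim = num_heads * head_dim
hidden = torch.randn(batch, seq_len, dim)

to_q = torch.nn.Linear(dim, dim, bias=False)
to_k = torch.nn.Linear(dim, dim, bias=False)
to_v = torch.nn.Linear(dim, dim, bias=False)

# [batch, seq, dim] -> [batch, num_heads, seq, head_dim], as SDPA expects
q = to_q(hidden).unflatten(2, (num_heads, -1)).transpose(1, 2)
k = to_k(hidden).unflatten(2, (num_heads, -1)).transpose(1, 2)
v = to_v(hidden).unflatten(2, (num_heads, -1)).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

# Back to [batch, seq, dim] before the output projection
out = out.transpose(1, 2).flatten(2, 3)
assert out.shape == (batch, seq_len, dim)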
fastvideo/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def register_model(
 
     def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
         all_supported_archs = self.get_supported_archs()
-
+        print('all_supported1', all_supported_archs)
         if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
