
Commit 11aaaae

Clean up
1 parent fc5f77e commit 11aaaae

2 files changed: +36 -123 lines changed


fastvideo/configs/models/dits/cosmos.py

Lines changed: 0 additions & 2 deletions
@@ -91,8 +91,6 @@ class CosmosArchConfig(DiTArchConfig):
     eps: float = 1e-6
     exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
 
-    # Attention backend selection - use torch to match diffusers behavior
-    attention_backend: str = "torch"  # Options: "distributed", "torch"
 
     def __post_init__(self):
         super().__post_init__()
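
For orientation, a self-contained sketch of how this region of the config reads once the field is gone; DiTArchConfig is replaced here by a trivial stand-in base class so the snippet runs on its own, and every other field of the real CosmosArchConfig is omitted (illustrative only, not repository code):

from dataclasses import dataclass, field


@dataclass
class DiTArchConfig:
    # Stand-in for fastvideo's real base class, for illustration only.
    def __post_init__(self) -> None:
        pass


@dataclass
class CosmosArchConfig(DiTArchConfig):
    # Only the fields visible in the hunk above; the attention_backend
    # knob no longer exists after this commit.
    eps: float = 1e-6
    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

    def __post_init__(self) -> None:
        super().__post_init__()


print(CosmosArchConfig())
# CosmosArchConfig(eps=1e-06, exclude_lora_layers=['embedder'])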

fastvideo/models/dits/cosmos.py

Lines changed: 36 additions & 121 deletions
@@ -168,52 +168,26 @@ def __init__(self,
                  num_heads: int,
                  qk_norm=True,
                  eps=1e-6,
-                 supported_attention_backends: tuple[AttentionBackendEnum, ...]
-                 | None = None,
-                 prefix: str = "",
-                 attention_backend: str = "distributed") -> None:
+                 prefix: str = "") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-        self.attention_backend = attention_backend
 
         # layers - use standard PyTorch layers when using torch backend
-        if attention_backend == "torch":
-            self.to_q = nn.Linear(dim, dim, bias=False)
-            self.to_k = nn.Linear(dim, dim, bias=False)
-            self.to_v = nn.Linear(dim, dim, bias=False)
-            self.to_out = nn.Linear(dim, dim, bias=False)
-        else:
-            self.to_q = ReplicatedLinear(dim, dim, bias=False)
-            self.to_k = ReplicatedLinear(dim, dim, bias=False)
-            self.to_v = ReplicatedLinear(dim, dim, bias=False)
-            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(dim, dim, bias=False)
+        self.to_v = nn.Linear(dim, dim, bias=False)
+        self.to_out = nn.Linear(dim, dim, bias=False)
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
 
-        # Attention mechanism - select backend
-        if attention_backend == "torch":
-            self.use_torch_attention = True
-        elif attention_backend == "distributed":
-            self.attn = DistributedAttention(
-                num_heads=num_heads,
-                head_size=self.head_dim,
-                dropout_rate=0,
-                softmax_scale=None,
-                causal=False,
-                supported_attention_backends=supported_attention_backends,
-                prefix=prefix)
-            self.use_torch_attention = False
-        else:
-            raise ValueError(f"Unsupported attention backend: {attention_backend}")
-
     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor | None = None,
@@ -224,14 +198,9 @@ def forward(self,
             encoder_hidden_states = hidden_states
 
         # Get QKV
-        if self.attention_backend == "torch":
-            query = self.to_q(hidden_states)
-            key = self.to_k(encoder_hidden_states)
-            value = self.to_v(encoder_hidden_states)
-        else:
-            query, _ = self.to_q(hidden_states)
-            key, _ = self.to_k(encoder_hidden_states)
-            value, _ = self.to_v(encoder_hidden_states)
+        query = self.to_q(hidden_states)
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
 
         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -256,21 +225,15 @@ def forward(self,
                                     use_real_unbind_dim=-2)
 
         # Attention computation
-        if self.use_torch_attention:
-            # Use standard PyTorch scaled dot product attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0
-            )
-        else:
-            attn_output, _ = self.attn(query, key, value)
-
+        # Use standard PyTorch scaled dot product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+        )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
 
         # Output projection
-        if self.attention_backend == "torch":
-            attn_output = self.to_out(attn_output)
-        else:
-            attn_output, _ = self.to_out(attn_output)
+        attn_output = self.to_out(attn_output)
+
         return attn_output
 
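Taken together, the hunks above collapse the self-attention path to plain nn.Linear projections followed by torch.nn.functional.scaled_dot_product_attention. A minimal runnable sketch of that pattern follows; it is a hypothetical standalone module, not the repository's class, and it omits rotary embeddings, QK normalization, and masking:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleSelfAttention(nn.Module):
    # Illustrative module mirroring the post-cleanup pattern: nn.Linear
    # projections plus scaled_dot_product_attention, nothing else.
    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [batch, seq, dim] -> [batch, heads, seq, head_dim]
        q = self.to_q(x).unflatten(2, (self.num_heads, -1)).transpose(1, 2)
        k = self.to_k(x).unflatten(2, (self.num_heads, -1)).transpose(1, 2)
        v = self.to_v(x).unflatten(2, (self.num_heads, -1)).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        # back to [batch, seq, dim]
        out = out.transpose(1, 2).flatten(2, 3)
        return self.to_out(out)


x = torch.randn(2, 16, 128)                   # [batch, seq, dim]
print(SimpleSelfAttention(128, 8)(x).shape)   # torch.Size([2, 16, 128])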

@@ -282,10 +245,7 @@ def __init__(self,
                  num_heads: int,
                  qk_norm=True,
                  eps=1e-6,
-                 supported_attention_backends: tuple[AttentionBackendEnum, ...]
-                 | None = None,
-                 prefix: str = "",
-                 attention_backend: str = "distributed") -> None:
+                 prefix: str = "") -> None:
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -294,65 +254,33 @@ def __init__(self,
         self.head_dim = dim // num_heads
         self.qk_norm = qk_norm
         self.eps = eps
-        self.attention_backend = attention_backend
 
         # layers - use standard PyTorch layers when using torch backend
-        if attention_backend == "torch":
-            self.to_q = nn.Linear(dim, dim, bias=False)
-            self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
-            self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
-            self.to_out = nn.Linear(dim, dim, bias=False)
-        else:
-            self.to_q = ReplicatedLinear(dim, dim, bias=False)
-            self.to_k = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-            self.to_v = ReplicatedLinear(cross_attention_dim, dim, bias=False)
-            self.to_out = ReplicatedLinear(dim, dim, bias=False)
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
+        self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
+        self.to_out = nn.Linear(dim, dim, bias=False)
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
         self.norm_k = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
 
-        # Attention mechanism - select backend
-        if attention_backend == "torch":
-            self.use_torch_attention = True
-        elif attention_backend == "distributed":
-            self.attn = LocalAttention(
-                num_heads=num_heads,
-                head_size=self.head_dim,
-                dropout_rate=0,
-                softmax_scale=None,
-                causal=False,
-                supported_attention_backends=supported_attention_backends)
-            self.use_torch_attention = False
-        else:
-            raise ValueError(f"Unsupported attention backend: {attention_backend}")
-
     def forward(self,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor | None = None) -> torch.Tensor:
 
         # Get QKV
-        if self.attention_backend == "torch":
-            query = self.to_q(hidden_states)
-            key = self.to_k(encoder_hidden_states)
-            value = self.to_v(encoder_hidden_states)
-        else:
-            query, _ = self.to_q(hidden_states)
-            key, _ = self.to_k(encoder_hidden_states)
-            value, _ = self.to_v(encoder_hidden_states)
+        query = self.to_q(hidden_states)
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
 
         # Reshape for multi-head attention
-        if self.use_torch_attention:
         # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
-            query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-            key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-            value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
-        else:
-            query = query.unflatten(2, (self.num_heads, -1))
-            key = key.unflatten(2, (self.num_heads, -1))
-            value = value.unflatten(2, (self.num_heads, -1))
+        query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
 
         # Apply normalization
         if self.norm_q is not None:
@@ -361,20 +289,14 @@ def forward(self,
             key = self.norm_k.forward_native(key)
 
         # Attention computation
-        if self.use_torch_attention:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0
-            )
-            attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
-        else:
-            attn_output = self.attn(query, key, value)
-            attn_output = attn_output.flatten(2, 3).type_as(query)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+        )
+        attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
 
         # Output projection
-        if self.attention_backend == "torch":
-            attn_output = self.to_out(attn_output)
-        else:
-            attn_output, _ = self.to_out(attn_output)
+        attn_output = self.to_out(attn_output)
+
         return attn_output
 
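Note the layout change that makes the surviving branch work: the deleted LocalAttention path kept tensors as [batch, seq_len, num_heads, head_dim], while scaled_dot_product_attention expects [batch, num_heads, seq_len, head_dim], hence the transpose(1, 2) before the call and the transpose back before flatten(2, 3). A shape-only sketch with made-up sizes (all names and dimensions here are illustrative, not from the repository):

import torch
import torch.nn.functional as F

# Random tensors; only the shapes matter here.
batch, seq_q, seq_kv, heads, head_dim = 2, 16, 77, 8, 64
q = torch.randn(batch, seq_q, heads * head_dim)
k = torch.randn(batch, seq_kv, heads * head_dim)
v = torch.randn(batch, seq_kv, heads * head_dim)

# [batch, seq, dim] -> [batch, heads, seq, head_dim], as SDPA expects.
q = q.unflatten(2, (heads, -1)).transpose(1, 2)   # [2, 8, 16, 64]
k = k.unflatten(2, (heads, -1)).transpose(1, 2)   # [2, 8, 77, 64]
v = v.unflatten(2, (heads, -1)).transpose(1, 2)   # [2, 8, 77, 64]

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
out = out.transpose(1, 2).flatten(2, 3)           # back to [2, 16, 512]
print(out.shape)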

@@ -389,10 +311,7 @@ def __init__(
         adaln_lora_dim: int = 256,
         qk_norm: str = "rms_norm",
         out_bias: bool = False,
-        supported_attention_backends: tuple[AttentionBackendEnum, ...]
-        | None = None,
         prefix: str = "",
-        attention_backend: str = "distributed",
     ) -> None:
         super().__init__()
 
@@ -404,9 +323,7 @@ def __init__(
             dim=hidden_size,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
-            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn1",
-            attention_backend=attention_backend)
+            prefix=f"{prefix}.attn1")
 
         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -415,9 +332,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
-            supported_attention_backends=supported_attention_backends,
-            prefix=f"{prefix}.attn2",
-            attention_backend=attention_backend)
+            prefix=f"{prefix}.attn2")
 
         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
@@ -598,7 +513,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class CosmosTransformer3DModel(BaseDiT):
     _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions
     _compile_conditions = CosmosVideoConfig()._compile_conditions
-    _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
+    # _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends
     param_names_mapping = CosmosVideoConfig().param_names_mapping
     lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping
 
@@ -651,9 +566,9 @@ def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None
                 adaln_lora_dim=config.adaln_lora_dim,
                 qk_norm=config.qk_norm,
                 out_bias=False,
-                supported_attention_backends=self._supported_attention_backends,
+                # supported_attention_backends=self._supported_attention_backends,
                 prefix=f"{config.prefix}.transformer_blocks.{i}",
-                attention_backend=config.arch_config.attention_backend,
+                #attention_backend=config.arch_config.attention_backend,
             ) for i in range(config.num_layers)
         ])
 