Commit 2c78d9b

Fix rope scale
1 parent 83c5dc3 commit 2c78d9b

File tree

fastvideo/configs/models/dits/cosmos.py
fastvideo/models/dits/cosmos.py

2 files changed: +42 -2 lines changed


fastvideo/configs/models/dits/cosmos.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ class CosmosArchConfig(DiTArchConfig):
     adaln_lora_dim: int = 256
     max_size: tuple[int, int, int] = (128, 240, 240)
     patch_size: tuple[int, int, int] = (1, 2, 2)
-    rope_scale: tuple[float, float, float] = (1.0, 4.0, 4.0)
+    rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
     concat_padding_mask: bool = True
     extra_pos_embed_type: str | None = None
     qk_norm: str = "rms_norm"
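For context on what this one-line change does: rope_scale feeds the NTK-aware base scaling in fastvideo/models/dits/cosmos.py (see the __init__ hunks below), where each per-axis factor is raised to dim / (dim - 2) and multiplied into the RoPE base of 10000. The temporal entry (index 0) stays at 1.0; only the two spatial entries drop from 4.0 to 3.0. A minimal sketch of that arithmetic, using an illustrative per-axis rotary dim of 40 (the model's actual dim_h/dim_w/dim_t are printed by the new init logging below):

# Effect of dropping the spatial rope_scale entries from 4.0 to 3.0,
# assuming an illustrative per-axis rotary dim of 40.
def ntk_theta(rope_scale: float, dim: int, base: float = 10000.0) -> float:
    # NTK-aware scaling: rope_scale ** (dim / (dim - 2)) stretches the
    # RoPE base instead of rescaling the position indices directly.
    return base * rope_scale ** (dim / (dim - 2))

for scale in (4.0, 3.0):
    print(f"rope_scale={scale}: theta={ntk_theta(scale, dim=40):.0f}")
# ~43,000 for 4.0 vs ~31,800 for 3.0 under these assumptions; a smaller
# scale gives a smaller theta, i.e. faster-rotating spatial frequencies.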

fastvideo/models/dits/cosmos.py

Lines changed: 41 additions & 1 deletion
@@ -545,6 +545,11 @@ def __init__(
     ) -> None:
         super().__init__()

+        # Log RoPE parameters
+        print(f"[FASTVIDEO ROPE INIT] hidden_size={hidden_size}, max_size={max_size}, patch_size={patch_size}, base_fps={base_fps}, rope_scale={rope_scale}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE INIT] hidden_size={hidden_size}, max_size={max_size}, patch_size={patch_size}, base_fps={base_fps}, rope_scale={rope_scale}\n")
+
         self.max_size = [
             size // patch
             for size, patch in zip(max_size, patch_size, strict=False)
@@ -560,20 +565,39 @@ def __init__(
         self.w_ntk_factor = rope_scale[2]**(self.dim_w / (self.dim_w - 2))
         self.t_ntk_factor = rope_scale[0]**(self.dim_t / (self.dim_t - 2))

+        print(f"[FASTVIDEO ROPE INIT] dim_h={self.dim_h}, dim_w={self.dim_w}, dim_t={self.dim_t}")
+        print(f"[FASTVIDEO ROPE INIT] h_ntk_factor={self.h_ntk_factor}, w_ntk_factor={self.w_ntk_factor}, t_ntk_factor={self.t_ntk_factor}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE INIT] dim_h={self.dim_h}, dim_w={self.dim_w}, dim_t={self.dim_t}\n")
+            f.write(f"[FASTVIDEO ROPE INIT] h_ntk_factor={self.h_ntk_factor}, w_ntk_factor={self.w_ntk_factor}, t_ntk_factor={self.t_ntk_factor}\n")
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 fps: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        fps = 16
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         pe_size = [
             num_frames // self.patch_size[0], height // self.patch_size[1],
             width // self.patch_size[2]
         ]
         device = hidden_states.device

+        print(f"[FASTVIDEO ROPE FORWARD] fps={fps}, base_fps={self.base_fps}")
+        print(f"[FASTVIDEO ROPE FORWARD] pe_size={pe_size}, patch_size={self.patch_size}")
+        print(f"[FASTVIDEO ROPE FORWARD] hidden_states.shape={hidden_states.shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE FORWARD] fps={fps}, base_fps={self.base_fps}\n")
+            f.write(f"[FASTVIDEO ROPE FORWARD] pe_size={pe_size}, patch_size={self.patch_size}\n")
+            f.write(f"[FASTVIDEO ROPE FORWARD] hidden_states.shape={hidden_states.shape}\n")
+
         h_theta = 10000.0 * self.h_ntk_factor
         w_theta = 10000.0 * self.w_ntk_factor
         t_theta = 10000.0 * self.t_ntk_factor

+        print(f"[FASTVIDEO ROPE FORWARD] h_theta={h_theta}, w_theta={w_theta}, t_theta={t_theta}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE FORWARD] h_theta={h_theta}, w_theta={w_theta}, t_theta={t_theta}\n")
+
         seq = torch.arange(max(self.max_size),
                            device=device,
                            dtype=torch.float32)
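One behavioral note on this hunk: fps = 16 is assigned unconditionally at the top of forward, so the incoming fps argument is ignored and the fps is None image branch in the final hunk below can no longer trigger; together with the print statements this looks like temporary debugging instrumentation rather than an intended API change.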
@@ -586,10 +610,20 @@ def forward(self,
         dim_t_range = (
             torch.arange(0, self.dim_t, 2, device=device,
                          dtype=torch.float32)[:(self.dim_t // 2)] / self.dim_t)
+        print(f"[FASTVIDEO ROPE FORWARD] max_size={self.max_size}, seq.shape={seq.shape}")
+        print(f"[FASTVIDEO ROPE FORWARD] dim_h_range.shape={dim_h_range.shape}, dim_w_range.shape={dim_w_range.shape}, dim_t_range.shape={dim_t_range.shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE FORWARD] max_size={self.max_size}, seq.shape={seq.shape}\n")
+            f.write(f"[FASTVIDEO ROPE FORWARD] dim_h_range.shape={dim_h_range.shape}, dim_w_range.shape={dim_w_range.shape}, dim_t_range.shape={dim_t_range.shape}\n")
+
         h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
         w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
         temporal_freqs = 1.0 / (t_theta**dim_t_range)

+        print(f"[FASTVIDEO ROPE FORWARD] h_spatial_freqs.shape={h_spatial_freqs.shape}, w_spatial_freqs.shape={w_spatial_freqs.shape}, temporal_freqs.shape={temporal_freqs.shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ROPE FORWARD] h_spatial_freqs.shape={h_spatial_freqs.shape}, w_spatial_freqs.shape={w_spatial_freqs.shape}, temporal_freqs.shape={temporal_freqs.shape}\n")
+
         emb_h = torch.outer(seq[:pe_size[1]],
                             h_spatial_freqs)[None, :, None, :].repeat(
                                 pe_size[0], 1, pe_size[2], 1)
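To make the shape bookkeeping in the emb_h line concrete, here is a self-contained sketch of how one axis's angle table is built and broadcast across the video grid; the sizes are illustrative, not taken from the commit:

import torch

pe_size = [2, 4, 4]   # latent (frames, height, width) after patching; illustrative
dim_h = 8             # rotary dims assigned to the height axis; illustrative

seq = torch.arange(max(pe_size), dtype=torch.float32)
dim_h_range = torch.arange(0, dim_h, 2, dtype=torch.float32)[:dim_h // 2] / dim_h
h_spatial_freqs = 1.0 / 10000.0**dim_h_range        # shape (dim_h // 2,)

# outer(positions, freqs) gives one rotation angle per (position, frequency)
# pair; the [None, :, None, :] indexing plus repeat broadcasts the height
# table across the frame and width axes of the grid.
emb_h = torch.outer(seq[:pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(
    pe_size[0], 1, pe_size[2], 1)
print(emb_h.shape)  # torch.Size([2, 4, 4, 4]) == (T, H, W, dim_h // 2)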
@@ -600,10 +634,16 @@ def forward(self,
         # Apply sequence scaling in temporal dimension
         if fps is None:
             # Images
+            print(f"[FASTVIDEO ROPE FORWARD] Using image mode (fps=None)")
             emb_t = torch.outer(seq[:pe_size[0]], temporal_freqs)
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO ROPE FORWARD] Using image mode (fps=None)\n")
         else:
             # Videos
-            emb_t = torch.outer(seq[:pe_size[0]] / fps * self.base_fps,
+            print(f"[FASTVIDEO ROPE FORWARD] Using video mode (fps={fps})")
+            temporal_scale = seq[:pe_size[0]] / fps * self.base_fps
+            print(f"[FASTVIDEO ROPE FORWARD] temporal_scale range: {temporal_scale.min().item():.6f} to {temporal_scale.max().item():.6f}")
+            emb_t = torch.outer(temporal_scale,
                                 temporal_freqs)

         emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
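The video branch above rescales temporal positions by base_fps / fps before the outer product, so frame indices are expressed in base-fps units and clips at different frame rates see consistent temporal frequencies. A small sketch of just that rescaling (values illustrative; base_fps is a constructor argument of the real module):

import torch

base_fps = 24.0                               # illustrative
seq = torch.arange(8, dtype=torch.float32)    # latent frame indices

for fps in (16.0, 24.0, 48.0):
    temporal_scale = seq / fps * base_fps
    # identity when fps == base_fps; higher fps compresses positions
    print(int(fps), temporal_scale[:4].tolist())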
