Skip to content

Commit d156461

Browse files
authored
Fix WanVideo (#461)
1 parent 6edf113 commit d156461

File tree

7 files changed

+845
-1
lines changed

7 files changed

+845
-1
lines changed

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self,
4343
self.cfg_rate = cfg_rate
4444
self.num_latent_t = num_latent_t
4545
self.local_indices = None
46-
self.plan_output_dir = os.path.join(self.path, "data_plan.json")
46+
self.plan_output_dir = os.path.join(
47+
self.path, f"data_plan_{self.world_size}_{self.sp_world_size}.json")
4748

4849
ranks = get_sp_group().ranks
4950
group_ranks: List[List] = [[] for _ in range(self.world_size)]
@@ -54,6 +55,7 @@ def __init__(self,
5455
# This will be useful when resume training
5556
if os.path.exists(self.plan_output_dir):
5657
print(f"Using existing plan from {self.plan_output_dir}")
58+
dist.barrier()
5759
return
5860

5961
# Find all parquet files recursively, and record num_rows for each file
@@ -87,6 +89,7 @@ def __init__(self,
8789

8890
with open(self.plan_output_dir, "w") as f:
8991
json.dump(plan, f)
92+
dist.barrier()
9093

9194
def __len__(self):
9295
if self.local_indices is None:

fastvideo/v1/models/dits/wanvideo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def forward(
319319
query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
320320
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
321321
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
322+
322323
# Apply rotary embeddings
323324
cos, sin = freqs_cis
324325
query, key = _apply_rotary_emb(query, cos, sin,

0 commit comments

Comments
 (0)