Commit 6c58a5b

[Bugfix] Fix VSA sp for training/inference (#574)
1 parent 48d9f61 commit 6c58a5b

6 files changed: +17 -18 lines

.github/workflows/pr-test.yml
Lines changed: 2 additions & 2 deletions

@@ -264,7 +264,7 @@ jobs:
     with:
       job_id: "training-test-VSA"
       gpu_type: "NVIDIA H100 NVL"
-      gpu_count: 1
+      gpu_count: 2
       volume_size: 100
       disk_size: 100
       image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"
@@ -284,7 +284,7 @@ jobs:
     with:
       job_id: "inference-test-STA"
       gpu_type: "NVIDIA H100 NVL"
-      gpu_count: 1
+      gpu_count: 2
       volume_size: 100
       disk_size: 100
       image: "ghcr.io/${{ github.repository }}/fastvideo-dev:py3.12-latest"

fastvideo/v1/tests/inference/STA/test_STA_inference.py
Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
 import pytest
 
 NUM_NODES = "1"
-NUM_GPUS_PER_NODE = "1"
+NUM_GPUS_PER_NODE = "2"
 
 # Set environment variables
 os.environ["FASTVIDEO_ATTENTION_CONFIG"] = "assets/mask_strategy_wan.json"
@@ -17,9 +17,9 @@ def test_inference():
     cmd = [
         "fastvideo", "generate",
         "--model-path", "Wan-AI/Wan2.1-T2V-14B-Diffusers",
-        "--sp-size", "1",
-        "--tp-size", "1",
-        "--num-gpus", "1",
+        "--sp-size", "2",
+        "--tp-size", "2",
+        "--num-gpus", "2",
         "--height", "768",
         "--width", "1280",
         "--num-frames", "69",

fastvideo/v1/tests/modal/pr_test.py
Lines changed: 2 additions & 2 deletions

@@ -69,11 +69,11 @@ def run_ssim_tests():
 def run_training_tests():
     run_test("wandb login $WANDB_API_KEY && pytest ./fastvideo/v1/tests/training/Vanilla -srP")
 
-@app.function(gpu="H100:1", image=image, timeout=1800, secrets=[modal.Secret.from_dict({"WANDB_API_KEY": os.environ.get("WANDB_API_KEY", "")})])
+@app.function(gpu="H100:2", image=image, timeout=1800, secrets=[modal.Secret.from_dict({"WANDB_API_KEY": os.environ.get("WANDB_API_KEY", "")})])
 def run_training_tests_VSA():
     run_test("wandb login $WANDB_API_KEY && pytest ./fastvideo/v1/tests/training/VSA -srP")
 
-@app.function(gpu="H100:1", image=image, timeout=1800)
+@app.function(gpu="H100:2", image=image, timeout=1800)
 def run_inference_tests_STA():
     run_test("pytest ./fastvideo/v1/tests/inference/STA -srP")
 
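
Modal encodes a GPU request as "<type>:<count>", so gpu="H100:2" asks for two H100s in a single container, matching the two-GPU test commands above. A self-contained sketch under that assumption (the app name and image here are placeholders; the real training decorator also attaches the WANDB secret):

import subprocess

import modal

app = modal.App("vsa-pr-tests-sketch")  # placeholder app name
image = modal.Image.debian_slim()       # placeholder; CI uses a prebuilt image

# gpu="H100:2" requests two H100s, enough for the sp_size=2 test runs.
@app.function(gpu="H100:2", image=image, timeout=1800)
def run_inference_tests_STA() -> None:
    subprocess.run(["pytest", "./fastvideo/v1/tests/inference/STA", "-srP"],
                   check=True)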

fastvideo/v1/tests/training/VSA/reference_wandb_summary_VSA.json
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-{"grad_norm":0.478515625,"_runtime":95.727033597,"_wandb":{"runtime":95},"_step":5,"validation_videos_50_steps":{"videos":[{"_type":"video-file","sha256":"42a1c311521a9d460db788713be1cbf2db767494e02619b43be5bf3eed8381d8","size":158632,"path":"media/videos/validation_videos_50_steps_0_42a1c311521a9d460db7.mp4"},{"path":"media/videos/validation_videos_50_steps_0_818505095b4b5e8b7f51.mp4","_type":"video-file","sha256":"818505095b4b5e8b7f511012d45f04d151ce3344bc058fc0f3225a414a851e4a","size":147825},{"sha256":"fc334ba9ed5e66c8527ee3b408e3be2d76167fef03588bf2840f4a0792f2fe34","size":136933,"path":"media/videos/validation_videos_50_steps_0_fc334ba9ed5e66c8527e.mp4","_type":"video-file"},{"size":201797,"path":"media/videos/validation_videos_50_steps_0_ccd98f6f907635d266a7.mp4","_type":"video-file","sha256":"ccd98f6f907635d266a74783688e7ecf1dac752d79d72d69eab9ef0e3f7413eb"},{"_type":"video-file","sha256":"ca79f40a0aed38f676f12779b349ce40e9e3fb7f36c578f49a20854c70508fb4","size":147114,"path":"media/videos/validation_videos_50_steps_0_ca79f40a0aed38f676f1.mp4"},{"size":175104,"path":"media/videos/validation_videos_50_steps_0_32c9b33ff920c17e5881.mp4","_type":"video-file","sha256":"32c9b33ff920c17e588133d7a27aa400ff3dc529b01ed4f16ac4d6bb2afa0f00"},{"sha256":"2cf520bfb93401c914e93c87ef791c2f12a4e043b95dfdc98115c930e11dfe67","size":139655,"path":"media/videos/validation_videos_50_steps_0_2cf520bfb93401c914e9.mp4","_type":"video-file"},{"_type":"video-file","sha256":"1d73aba17ce582c7aef4af4d64079e3e9d3df205634eff453446bdaf2340b214","size":149028,"path":"media/videos/validation_videos_50_steps_0_1d73aba17ce582c7aef4.mp4"}],"captions":false,"_type":"videos","count":8},"train_loss":0.08922439813613892,"_timestamp":1.750202051751466e+09,"avg_step_time":0.7536672964692116,"step_time":0.4742048177868128,"learning_rate":1e-05,"vsa_sparsity":0.05}
+{"step_time":0.6983645600266755,"_wandb":{"runtime":107},"grad_norm":0.50390625,"avg_step_time":1.002151239803061,"_step":5,"validation_videos_50_steps":{"captions":false,"_type":"videos","count":8,"videos":[{"size":159131,"path":"media/videos/validation_videos_50_steps_0_dc447599dbe48350e9c9.mp4","_type":"video-file","sha256":"dc447599dbe48350e9c920f4971e1786bde580dd20d52b4aa147ae8d3dc564d6"},{"_type":"video-file","sha256":"4e283876ddfbf5a2cb6f8aca07a39f832b5f806fbefde8854bc73ed904ff20ee","size":160315,"path":"media/videos/validation_videos_50_steps_0_4e283876ddfbf5a2cb6f.mp4"},{"size":135225,"path":"media/videos/validation_videos_50_steps_0_78185c41e1935306e93c.mp4","_type":"video-file","sha256":"78185c41e1935306e93c2d416ee40b31d038abffb029cb5bfb11c2a634eb2fcf"},{"_type":"video-file","sha256":"27e9819d002d3f63c8918bbdc5bf2857b5effe0caf5d5b8374b9c590fc6432eb","size":197873,"path":"media/videos/validation_videos_50_steps_0_27e9819d002d3f63c891.mp4"},{"_type":"video-file","sha256":"46fe548e86144ca60a9396fcd15e8788d3a05c066c6819ccdd6a9041feaaec8f","size":170601,"path":"media/videos/validation_videos_50_steps_0_46fe548e86144ca60a93.mp4"},{"sha256":"91ec338774bec870b9c5c81be4330a8f0f2535124cb372e881c3703e6b65ed77","size":164462,"path":"media/videos/validation_videos_50_steps_0_91ec338774bec870b9c5.mp4","_type":"video-file"},{"_type":"video-file","sha256":"ee4e811080a619215fd7541b39203dfb69c2db4cdadbfdd7a234a2cff684f6f3","size":139435,"path":"media/videos/validation_videos_50_steps_0_ee4e811080a619215fd7.mp4"},{"sha256":"22e31e048ba5e5b9d6587d7306904d6edc6c618b8f77c3ceb05ba2d602309274","size":147072,"path":"media/videos/validation_videos_50_steps_0_22e31e048ba5e5b9d658.mp4","_type":"video-file"}]},"_timestamp":1.75118195270901e+09,"vsa_sparsity":0.05,"learning_rate":1e-05,"train_loss":0.19960195198655128,"_runtime":107.325113071}

fastvideo/v1/tests/training/VSA/test_training_loss_VSA.py
Lines changed: 7 additions & 7 deletions

@@ -15,7 +15,7 @@
 reference_wandb_summary_file = "fastvideo/v1/tests/training/VSA/reference_wandb_summary_VSA.json"
 
 NUM_NODES = "1"
-NUM_GPUS_PER_NODE = "1"
+NUM_GPUS_PER_NODE = "2"
 
 os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "VIDEO_SPARSE_ATTN"
 
@@ -35,14 +35,14 @@ def run_worker():
         "--validation_preprocessed_path", "data/mini_dataset_i2v_VSA/validation_parquet_dataset",
         "--train_batch_size", "1",
         "--num_latent_t", "4",
-        "--num_gpus", "1",
-        "--sp_size", "1",
-        "--tp_size", "1",
+        "--num_gpus", "2",
+        "--sp_size", "2",
+        "--tp_size", "2",
         "--hsdp_replicate_dim", "1",
-        "--hsdp_shard_dim", "1",
+        "--hsdp_shard_dim", "2",
         "--train_sp_batch_size", "1",
         "--dataloader_num_workers", "4",
-        "--gradient_accumulation_steps", "1",
+        "--gradient_accumulation_steps", "2",
         "--max_train_steps", "5",
         "--learning_rate", "1e-5",
         "--mixed_precision", "bf16",
@@ -110,7 +110,7 @@ def test_distributed_training():
     fields_and_thresholds = {
        'avg_step_time': 1.0,
        'grad_norm': 0.1,
-       'step_time': 0.5,
+       'step_time': 1.0,
        'train_loss': 0.001
     }
 
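
The step_time tolerance is loosened from 0.5 to 1.0 to absorb the extra per-step variance of the two-GPU run. A hypothetical sketch of how per-field thresholds like these are typically checked against the committed reference summary (check_summary and its signature are assumptions, not the test's actual helper):

import json

def check_summary(summary_path: str, reference_path: str,
                  fields_and_thresholds: dict[str, float]) -> None:
    # Each tracked field must stay within an absolute tolerance of the
    # value stored in the reference wandb summary.
    with open(summary_path) as f:
        summary = json.load(f)
    with open(reference_path) as f:
        reference = json.load(f)
    for field, threshold in fields_and_thresholds.items():
        drift = abs(summary[field] - reference[field])
        assert drift <= threshold, (
            f"{field} drifted by {drift:.4f} (threshold {threshold})")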

fastvideo/v1/training/training_pipeline.py
Lines changed: 1 addition & 2 deletions

@@ -245,9 +245,8 @@ def _build_attention_metadata(
         current_vsa_sparsity = training_batch.current_vsa_sparsity
 
         if vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN":
-
             dit_seq_shape = [
-                latents.shape[2] // patch_size[0],
+                latents.shape[2] * self.sp_world_size // patch_size[0],
                 latents.shape[3] // patch_size[1],
                 latents.shape[4] // patch_size[2]
             ]
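
This one-line change is the actual bugfix: under sequence parallelism each rank holds only a temporal slice of the latents, so latents.shape[2] is the full frame count divided by sp_world_size, while the VSA block mask must be built from the full video shape. A minimal sketch of the arithmetic, assuming SP shards the latents evenly along the temporal axis (full_dit_seq_shape is a hypothetical name for illustration):

def full_dit_seq_shape(local_latents_shape, patch_size, sp_world_size):
    # local_latents_shape: (B, C, T_local, H, W) as seen by one SP rank,
    # where T_local = T_full / sp_world_size. Multiplying restores the
    # full temporal extent before dividing by the patch size.
    _, _, t_local, h, w = local_latents_shape
    return [
        t_local * sp_world_size // patch_size[0],
        h // patch_size[1],
        w // patch_size[2],
    ]

# With 2-way SP each rank holds 4 of 8 latent frames; the full token
# grid is still 8 x 30 x 52.
assert full_dit_seq_shape((1, 16, 4, 60, 104), (1, 2, 2), 2) == [8, 30, 52]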
