
Commit 4ad17b5

update test with multigpu (#207)
* update test
* update
1 parent 6936311 commit 4ad17b5

File tree

5 files changed: 10 additions, 24 deletions

* tests/pipelines/cogvideox/test_cogvideox.py
* tests/pipelines/latte/test_latte.py
* tests/pipelines/open_sora/test_open_sora.py
* tests/pipelines/open_sora_plan/test_open_sora_plan.py
* videosys/models/transformers/open_sora_plan_transformer_3d.py


tests/pipelines/cogvideox/test_cogvideox.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from videosys import CogVideoXConfig, VideoSysEngine
 
 
-@pytest.mark.parametrize("num_gpus", [1])
+@pytest.mark.parametrize("num_gpus", [1, 2])
 def test_base(num_gpus):
     config = CogVideoXConfig(num_gpus=num_gpus)
     engine = VideoSysEngine(config)

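Only the first lines of this test appear in the hunk. Below is a minimal sketch of how the parametrized test presumably continues, following the prompt/generate/save_video pattern used by the other test files in this commit; the prompt and the output filename are illustrative, not taken from the repository.

import pytest

from videosys import CogVideoXConfig, VideoSysEngine


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_base(num_gpus):
    # Configure the pipeline for the requested GPU count and build the engine.
    config = CogVideoXConfig(num_gpus=num_gpus)
    engine = VideoSysEngine(config)

    # Not shown in the diff: assumed to mirror the other pipeline tests.
    prompt = "Sunset over the sea."
    video = engine.generate(prompt).video[0]
    engine.save_video(video, f"./test_outputs/{prompt}_cogvideox_{num_gpus}.mp4")
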
tests/pipelines/latte/test_latte.py

Lines changed: 2 additions & 2 deletions
@@ -13,9 +13,9 @@ def test_base(num_gpus):
     engine.save_video(video, f"./test_outputs/{prompt}_latte_{num_gpus}.mp4")
 
 
-@pytest.mark.parametrize("num_gpus", [1, 2])
+@pytest.mark.parametrize("num_gpus", [1])
 def test_pab(num_gpus):
-    config = LatteConfig(num_gpus=num_gpus)
+    config = LatteConfig(num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)
 
     prompt = "Sunset over the sea."

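The PAB test now opts in explicitly through the config. Assuming enable_pab toggles VideoSys's Pyramid Attention Broadcast path (the flag name comes from the diff; what it enables is an assumption), standalone usage would mirror the test body, as in this sketch:

from videosys import LatteConfig, VideoSysEngine

# enable_pab=True presumably switches on Pyramid Attention Broadcast for Latte;
# the PAB test is now pinned to a single GPU while the base test covers multi-GPU.
config = LatteConfig(num_gpus=1, enable_pab=True)
engine = VideoSysEngine(config)

video = engine.generate("Sunset over the sea.").video[0]
engine.save_video(video, "./test_outputs/latte_pab_example.mp4")  # illustrative path
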
tests/pipelines/open_sora/test_open_sora.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def test_base(num_gpus):
     engine.save_video(video, f"./test_outputs/{prompt}_open_sora_{num_gpus}.mp4")
 
 
-@pytest.mark.parametrize("num_gpus", [1, 2])
+@pytest.mark.parametrize("num_gpus", [1])
 def test_pab(num_gpus):
     config = OpenSoraConfig(num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)

tests/pipelines/open_sora_plan/test_open_sora_plan.py

Lines changed: 3 additions & 3 deletions
@@ -10,12 +10,12 @@ def test_base(num_gpus):
 
     prompt = "Sunset over the sea."
     video = engine.generate(prompt).video[0]
-    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_pab_{num_gpus}.mp4")
+    engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{num_gpus}.mp4")
 
 
-@pytest.mark.parametrize("num_gpus", [1, 2])
+@pytest.mark.parametrize("num_gpus", [1])
 def test_pab(num_gpus):
-    config = OpenSoraPlanConfig(num_gpus=num_gpus)
+    config = OpenSoraPlanConfig(num_gpus=num_gpus, enable_pab=True)
     engine = VideoSysEngine(config)
 
     prompt = "Sunset over the sea."

videosys/models/transformers/open_sora_plan_transformer_3d.py

Lines changed: 3 additions & 17 deletions
@@ -1258,23 +1258,9 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
-        if self.attention_mode == "flash":
-            assert attention_mask is None or torch.all(
-                attention_mask.bool()
-            ), "flash-attn do not support attention_mask"
-            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
-                hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        elif self.attention_mode == "xformers":
-            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-                )
-        elif self.attention_mode == "math":
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
-        else:
-            raise NotImplementedError(f"Found attention_mode: {self.attention_mode}")
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

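The processor change above drops the explicit flash/xformers/math branching in favor of a single F.scaled_dot_product_attention call, relying on PyTorch 2.x to choose a backend and to honor attn_mask where supported. A self-contained sketch of that call with illustrative shapes (not VideoSys code):

import torch
import torch.nn.functional as F

# Illustrative shapes; PyTorch dispatches SDPA to a flash, memory-efficient,
# or math kernel on its own, so no sdp_kernel() context manager is needed.
batch_size, heads, seq_len, head_dim = 2, 8, 128, 64
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)
attention_mask = None  # an optional boolean or additive mask may be passed

hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Same post-processing as the processor: (batch, seq_len, heads * head_dim).
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)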