Skip to content

Commit ed0c542

Browse files
authored
Merge pull request #254 from HKUSTDial:optim_triton_version
Refactor test utilities and add CUDA tensor operation tests
2 parents e76fc45 + fa21a99 commit ed0c542

13 files changed

+1229
-1
lines changed

tests/test_dense_base_backward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_backward_base_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_dense_base_backward_correctness(is_causal: bool) -> None:
    """Run the dense backward base case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # GQA-style config: 8 query heads over 4 KV heads, non-square seq lengths.
    case = dict(
        kind="dense",
        batch_size=2,
        seqlen_q=80,
        seqlen_k=96,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
    )
    run_backward_base_case(**case)

tests/test_dense_base_forward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_forward_base_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_dense_base_forward_correctness(is_causal: bool) -> None:
    """Run the dense forward base case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # GQA-style config: 8 query heads over 4 KV heads, non-square seq lengths.
    case = dict(
        kind="dense",
        batch_size=2,
        seqlen_q=96,
        seqlen_k=128,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
    )
    run_forward_base_case(**case)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_backward_varlen_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_dense_varlen_backward_correctness(is_causal: bool) -> None:
    """Run the dense backward varlen case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Ragged batch: three sequences with distinct, non-aligned q/k lengths.
    case = dict(
        kind="dense",
        lens_q=[19, 27, 41],
        lens_k=[25, 31, 43],
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
    )
    run_backward_varlen_case(**case)

tests/test_dense_varlen_forward.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_forward_varlen_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_dense_varlen_forward_correctness(is_causal: bool) -> None:
    """Run the dense forward varlen case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Ragged batch: three sequences with distinct, non-aligned q/k lengths.
    case = dict(
        kind="dense",
        lens_q=[17, 33, 29],
        lens_k=[23, 37, 31],
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
    )
    run_forward_varlen_case(**case)

tests/test_gated_base_backward.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_backward_base_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_gated_base_backward_correctness(is_causal: bool) -> None:
    """Run the gated backward base case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Same shape config as the dense test, plus the log-sigmoid gate path.
    case = dict(
        kind="gated",
        batch_size=2,
        seqlen_q=80,
        seqlen_k=96,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
        is_logsigmoid_gate=True,
    )
    run_backward_base_case(**case)

tests/test_gated_base_forward.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_forward_base_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_gated_base_forward_correctness(is_causal: bool) -> None:
    """Run the gated forward base case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Same shape config as the dense test, plus the log-sigmoid gate path.
    case = dict(
        kind="gated",
        batch_size=2,
        seqlen_q=96,
        seqlen_k=128,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
        is_logsigmoid_gate=True,
    )
    run_forward_base_case(**case)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_backward_varlen_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_gated_varlen_backward_correctness(is_causal: bool) -> None:
    """Run the gated backward varlen case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Ragged batch plus the log-sigmoid gate path.
    case = dict(
        kind="gated",
        lens_q=[19, 27, 41],
        lens_k=[25, 31, 43],
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
        is_logsigmoid_gate=True,
    )
    run_backward_varlen_case(**case)

tests/test_gated_varlen_forward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_forward_varlen_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_gated_varlen_forward_correctness(is_causal: bool) -> None:
    """Run the gated forward varlen case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Ragged batch plus the log-sigmoid gate path.
    case = dict(
        kind="gated",
        lens_q=[17, 33, 29],
        lens_k=[23, 37, 31],
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
        is_logsigmoid_gate=True,
    )
    run_forward_varlen_case(**case)

tests/test_sparse_base_backward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch

from test_utils import run_backward_base_case, set_seed

# These kernels only run on GPU; skip the whole module without a CUDA device.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required"
)


@pytest.mark.parametrize("is_causal", [False, True])
def test_sparse_base_backward_correctness(is_causal: bool) -> None:
    """Run the sparse backward base case (checks live in the shared helper)."""
    set_seed(0)  # deterministic inputs across runs
    # Same shape config as the dense backward test, sparse kernel variant.
    case = dict(
        kind="sparse",
        batch_size=2,
        seqlen_q=80,
        seqlen_k=96,
        num_heads_q=8,
        num_heads_kv=4,
        head_dim=64,
        is_causal=is_causal,
    )
    run_backward_base_case(**case)

tests/test_sparse_base_forward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import torch
3+
4+
from test_utils import run_forward_base_case, set_seed
5+
6+
pytestmark = pytest.mark.skipif(
7+
not torch.cuda.is_available(), reason="CUDA is required"
8+
)
9+
10+
11+
@pytest.mark.parametrize("is_causal", [False, True])
12+
def test_sparse_base_forward_correctness(is_causal: bool) -> None:
13+
set_seed(0)
14+
run_forward_base_case(
15+
kind="sparse",
16+
batch_size=2,
17+
seqlen_q=96,
18+
seqlen_k=128,
19+
num_heads_q=8,
20+
num_heads_kv=4,
21+
head_dim=64,
22+
is_causal=is_causal,
23+
)

0 commit comments

Comments
 (0)