
Commit e3f9f43

[BENCH] Use the device fixture for all bench tests (triton-lang#6706)
Also fixed a typo in `test_swiglu`.
1 parent e0262f5 commit e3f9f43

File tree

5 files changed: +36 −28 lines

    bench/tests/conftest.py
    bench/tests/test_compaction.py
    bench/tests/test_matmul.py
    bench/tests/test_routing.py
    bench/tests/test_swiglu.py

bench/tests/conftest.py

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default="cuda")
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
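
This fixture is what the tests below pick up by declaring `device` as a parameter, instead of hard-coding "cuda". A minimal sketch of how a test consumes it and how the suite can be pointed at another backend (the test name and the `--device=cpu` value are illustrative only, not part of this commit):

import torch

def test_device_fixture_example(device):   # hypothetical test; `device` is injected by the conftest fixture above
    x = torch.zeros(4, device=device)      # allocated on whatever --device selected (default "cuda")
    assert x.numel() == 4

# The whole bench suite can then be retargeted from the command line, e.g.:
#   pytest bench/tests --device=cpu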

bench/tests/test_compaction.py

Lines changed: 1 addition & 2 deletions

@@ -9,8 +9,7 @@
     (131, 128, 16, 0.6),
     (496, 128, 16, 0.),
 ])
-def test_compaction(n_tokens, n_cols, k, p):
-    device = "cuda"
+def test_compaction(n_tokens, n_cols, k, p, device):
     yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1)
     yi = yi[:, :k].to(torch.int32)
     yv = torch.randn((n_tokens, k), dtype=torch.bfloat16, device=device)

bench/tests/test_matmul.py

Lines changed: 18 additions & 16 deletions

@@ -39,9 +39,8 @@ def mask_indx(idx, n_expts_act):
     return idx


-def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter):
-    dev = "cuda"
-    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=dev, requires_grad=True)
+def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
+    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
     routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
     routing_data.gate_scal = None
     gather_idx = gather_idx if do_gather else None
@@ -50,17 +49,18 @@ def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_


 def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
-                      has_y_gammas, requires_grad=True, dev="cuda"):
+                      has_y_gammas, requires_grad=True, device="cuda"):
     torch.manual_seed(0)
     assert mode in {'batched', 'ragged'}
     in_m = m * (n_expts_act if gindx is None else 1)
     out_m = m * (n_expts_act if sindx is None else 1)
     shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
-    x = alloc_rand(shape_x, device=dev, dtype=act_dtype, requires_grad=requires_grad)
-    w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=dev, dtype=weight_dtype, requires_grad=requires_grad)
-    bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=dev, dtype=torch.float32, requires_grad=requires_grad)
-    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=dev, dtype=torch.float32, requires_grad=requires_grad)
-    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=dev, dtype=torch.float32, requires_grad=requires_grad)
+    x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
+    w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
+    bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32,
+                      requires_grad=requires_grad)
+    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
+    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
     gs0 = gs0.detach().requires_grad_(requires_grad)
     gs1 = gs1.detach().requires_grad_(requires_grad)
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
@@ -75,12 +75,13 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
 # ---------------


-def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx()):
+def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx(),
+                   device="cuda"):
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                   ([val0]
-                                                   if n_expts_tot % 2 else []), dtype=torch.float32, device="cuda")
-    make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device="cuda")
+                                                   if n_expts_tot % 2 else []), dtype=torch.float32, device=device)
+    make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device)
     in_flex_data = lambda scale, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_scalar(scale)
                                                       ) if use_flex else InFlexData()
     in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_tensor(
@@ -211,7 +212,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale):
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale, device):
     # TODO: remove when Triton FP8 supports proper RTNE
     if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -254,16 +255,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     act_is_float8 = act_dtype.itemsize == 1
     weight_is_float8 = weight_dtype.itemsize == 1
     precision_opt = init_precision(act_dtype, act_is_float8, weight_is_float8 and not is_mixed_input,
-                                   n_expts_tot // n_expt_shards)
+                                   n_expts_tot // n_expt_shards, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
-        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter)
+        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
+                                                   device=device)
     else:
         rdata = gindx = sindx = None
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
                                                                  n_expt_shards, mode, act_dtype, #
                                                                  torch.bfloat16 if is_mixed_input else weight_dtype,
-                                                                 has_y_gammas, requires_grad=test_bwd)
+                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

     if is_mixed_input:

bench/tests/test_routing.py

Lines changed: 4 additions & 5 deletions

@@ -6,12 +6,11 @@
 from triton_bench.testing import assert_equal


-def init_data(n_tokens, n_expts_tot, dtype=torch.float16):
-    dev = "cuda"
+def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
     # the reference implementation and the triton implementation do not tie-break experts the same way
     randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)]
     x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype)) for i, bits in enumerate(randbits)]
-    return torch.stack(x).to(device=dev)
+    return torch.stack(x).to(device=device)


 def ref_expt_data(routing_data, n_gates, block_m):
@@ -46,9 +45,9 @@ def ref_expt_data(routing_data, n_gates, block_m):
 @pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024])
 @pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)])
 @pytest.mark.parametrize("block_m", [64, 128])
-def test_op(n_tokens, n_expts_tot, n_expts_act, block_m):
+def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, device):
     torch.manual_seed(2)
-    tri_logits = init_data(n_tokens, n_expts_tot).detach()
+    tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach()
     ref_logits = tri_logits.clone()
     ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act)
     tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)

bench/tests/test_swiglu.py

Lines changed: 3 additions & 5 deletions

@@ -26,10 +26,8 @@ def alloc_rand(shape, device, dtype, requires_grad=True):

 @pytest.mark.parametrize("M, N", [(1311, 4352)])
 @pytest.mark.parametrize("limit", [1e-2, 10])
-def test_op(M, N, limit, alpha=0.5):
+def test_op(M, N, limit, device, alpha=0.5):
     torch.manual_seed(2)
-    dev = "cuda"
-    dtype = torch.bfloat16
     # initialize expert data
     n_expts_tot = 6
     n_expts_act = 2
@@ -39,8 +37,8 @@ def test_op(M, N, limit, alpha=0.5):
     n_tokens = expt_data[2 * n_expts_tot].sum()

     # initialize data
-    x = alloc_rand([n_tokens, N], device=dev, dtype=dtype)
+    x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16)
     precision_config = PrecisionConfig(limit=limit)
-    tri_y = swiglu(x, alpha, precision_config, expt_data, n_expts_tot)
+    tri_y = swiglu(x, alpha, precision_config, routing_data, n_expts_tot)
     ref_y = swiglu_torch(x, alpha, precision_config)
     assert_close(tri_y, ref_y)
