Skip to content

Commit 54540a8

Browse files
committed
Parametrize test_divide (2/?)
1 parent 19aebbe commit 54540a8

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

python/test/regression/test_divide.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ def patch_kernel(template, to_replace):
1212
for key, value in to_replace.items():
1313
kernel.src = kernel.src.replace(key, value)
1414
return kernel
15-
16-
def test_divide(device):
15+
16+
@pytest.mark.parametrize("float_div", [True, False])
17+
@pytest.mark.parametrize("floor", [True, False])
18+
@pytest.mark.parametrize("trunc", [True, False])
19+
def test_divide(float_div, floor, trunc, device):
1720
# regression test for various division cases
1821

1922
@triton.jit
@@ -30,36 +33,38 @@ def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel
3033
tmp4 = tmp1 / tmp3
3134
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), tmp0 // tmp2)
3235
tmp6 = tmp0 // tmp2
33-
tl.store(out_ptr0 + (x0), tmp4, xmask)
34-
tl.store(out_ptr1 + (x0), tmp5, xmask)
35-
tl.store(out_ptr2 + (x0), tmp6, xmask)
36-
tl.store(out_ptr3 + (x0), tmp4, xmask)
37-
tl.store(out_ptr4 + (x0), tmp5, xmask)
36+
GENERATE_OUTPUTS_HERE
3837

3938
torch.manual_seed(0)
4039

40+
outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div is True else ""
41+
outputs_floor = "\n tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor is True else ""
42+
outputs_trunc = "\n tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc is True else ""
43+
44+
divide_kernel = patch_kernel(divide_kernel, {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
45+
4146
def launch_triton(a, b):
42-
output0 = torch.empty_like(a)
43-
output1 = torch.empty_like(a)
44-
output2 = torch.empty_like(a)
45-
output3 = torch.empty_like(a)
46-
output4 = torch.empty_like(a)
47+
output0 = torch.zeros_like(a)
48+
output1 = torch.zeros_like(a)
49+
output2 = torch.zeros_like(a)
50+
output3 = torch.zeros_like(a)
51+
output4 = torch.zeros_like(a)
4752

4853
n_elements = output0.numel()
4954

5055
grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )
51-
56+
5257
divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)
5358

5459
return (output0, output1, output2, output3, output4)
5560

5661
def launch_torch(a, b):
5762
return (
58-
aten.div(a, b, rounding_mode=None),
59-
aten.div(a, b, rounding_mode="floor"),
60-
aten.div(a, b, rounding_mode="trunc"),
61-
a / b,
62-
a // b,
63+
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
64+
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
65+
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
66+
a / b if float_div is True else torch.zeros_like(a),
67+
a // b if floor is True else torch.zeros_like(a),
6368
)
6469

6570
a = torch.randint(2**32, 2**40, [100, 100], device=device)
@@ -70,6 +75,6 @@ def launch_torch(a, b):
7075
torch_result = launch_torch(a, b)
7176

7277
for i in range(5):
73-
torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Iteration {iter}, {i} failed\n{msg}")
78+
torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
7479

7580

0 commit comments

Comments
 (0)