Skip to content

Commit 19aebbe

Browse files
committed
Add regression test 1/?
1 parent 012d2cc commit 19aebbe

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
aten = torch.ops.aten
3+
4+
import pytest
5+
6+
import triton
7+
import triton.language as tl
8+
9+
10+
def patch_kernel(template, to_replace):
    """Return a new JIT function built from ``template`` with patched source.

    A fresh ``triton.JITFunction`` is created from ``template.fn`` and every
    occurrence of each key of ``to_replace`` in its ``src`` text is replaced
    by the corresponding value, in the mapping's iteration order.

    Args:
        template: Object exposing the underlying Python function as ``.fn``.
        to_replace: Mapping of substring -> replacement text.

    Returns:
        The newly created ``triton.JITFunction`` with the rewritten ``src``.
    """
    patched = triton.JITFunction(template.fn)
    # Assign back to ``src`` once per substitution so the attribute is
    # written exactly as often as there are entries in the mapping.
    for needle, substitute in to_replace.items():
        patched.src = patched.src.replace(needle, substitute)
    return patched
15+
16+
def test_divide(device):
    """Regression test for various division cases (Triton vs. PyTorch).

    One Triton kernel computes five outputs from two integer tensors and each
    output is compared against the PyTorch expression at the same position:

    0. ``aten.div(a, b, rounding_mode=None)``
    1. ``aten.div(a, b, rounding_mode="floor")``
    2. ``aten.div(a, b, rounding_mode="trunc")``
    3. ``a / b``
    4. ``a // b``

    The same inputs are launched and compared 100 times.

    Args:
        device: Device on which the input tensors are created; it is passed
            straight to ``torch.randint(..., device=device)``.

    Raises:
        AssertionError: If any Triton output is not close to its PyTorch
            counterpart (dtype differences are ignored).
    """

    # NOTE(review): the kernel expressions are left exactly as written, on the
    # assumption that their exact form matters to what this regression test
    # guards -- confirm before refactoring them.
    @triton.jit
    def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(a + (x0), xmask)
        tmp2 = tl.load(b + (x0), xmask)
        # custom bits
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        # Division of the float32-converted operands (checked against outputs 0 and 3).
        tmp4 = tmp1 / tmp3
        # Where the operand signs differ and the remainder is non-zero, subtract
        # one from the quotient (checked against outputs 1 and 4).
        tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), tmp0 // tmp2)
        # Plain ``//`` on the integer operands (checked against output 2).
        tmp6 = tmp0 // tmp2
        tl.store(out_ptr0 + (x0), tmp4, xmask)
        tl.store(out_ptr1 + (x0), tmp5, xmask)
        tl.store(out_ptr2 + (x0), tmp6, xmask)
        tl.store(out_ptr3 + (x0), tmp4, xmask)
        tl.store(out_ptr4 + (x0), tmp5, xmask)

    torch.manual_seed(0)

    def launch_triton(a, b):
        """Run ``divide_kernel`` on ``a`` and ``b`` and return its five outputs."""
        # One output buffer per kernel output pointer, each created with
        # ``torch.empty_like(a)``.
        outputs = tuple(torch.empty_like(a) for _ in range(5))

        n_elements = outputs[0].numel()

        grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )

        divide_kernel[grid](a, b, *outputs, n_elements, XBLOCK=128)

        return outputs

    def launch_torch(a, b):
        """Return the five PyTorch reference results, in kernel output order."""
        return (
            aten.div(a, b, rounding_mode=None),
            aten.div(a, b, rounding_mode="floor"),
            aten.div(a, b, rounding_mode="trunc"),
            a / b,
            a // b,
        )

    a = torch.randint(2**32, 2**40, [100, 100], device=device)
    b = torch.randint(-10, -1, [100, 100], device=device)

    # ``iteration`` rather than ``iter`` so the builtin is not shadowed.
    for iteration in range(100):
        triton_result = launch_triton(a, b)
        torch_result = launch_torch(a, b)

        # Guard against ``zip`` silently dropping an unmatched output.
        assert len(triton_result) == len(torch_result)

        for i, (actual, expected) in enumerate(zip(triton_result, torch_result)):
            torch.testing.assert_close(
                actual,
                expected,
                check_dtype=False,
                msg=lambda msg: f"Iteration {iteration}, {i} failed\n{msg}",
            )
74+
75+

0 commit comments

Comments
 (0)