Skip to content

Commit c8027e1

Browse files
committed
fixup format in test_divide
1 parent 54540a8 commit c8027e1

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed
Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
# flake8: noqa: F821, F841
12
import torch
2-
aten = torch.ops.aten
3-
4-
import pytest
3+
import pytest
54

65
import triton
7-
import triton.language as tl
6+
import triton.language as tl
7+
8+
aten = torch.ops.aten
89

910

1011
def patch_kernel(template, to_replace):
@@ -13,25 +14,27 @@ def patch_kernel(template, to_replace):
1314
kernel.src = kernel.src.replace(key, value)
1415
return kernel
1516

17+
1618
@pytest.mark.parametrize("float_div", [True, False])
1719
@pytest.mark.parametrize("floor", [True, False])
1820
@pytest.mark.parametrize("trunc", [True, False])
1921
def test_divide(float_div, floor, trunc, device):
20-
# regression test for various division cases
22+
# regression test for various division cases
2123

22-
@triton.jit
24+
@triton.jit
2325
def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
2426
xoffset = tl.program_id(0) * XBLOCK
2527
xindex = xoffset + tl.arange(0, XBLOCK)[:]
2628
xmask = xindex < xnumel
2729
x0 = xindex
2830
tmp0 = tl.load(a + (x0), xmask)
2931
tmp2 = tl.load(b + (x0), xmask)
30-
# custom bits
32+
# custom bits
3133
tmp1 = tmp0.to(tl.float32)
3234
tmp3 = tmp2.to(tl.float32)
3335
tmp4 = tmp1 / tmp3
34-
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), tmp0 // tmp2)
36+
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
37+
tmp0 // tmp2)
3538
tmp6 = tmp0 // tmp2
3639
GENERATE_OUTPUTS_HERE
3740

@@ -41,7 +44,8 @@ def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel
4144
outputs_floor = "\n tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor is True else ""
4245
outputs_trunc = "\n tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc is True else ""
4346

44-
divide_kernel = patch_kernel(divide_kernel, {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
47+
divide_kernel = patch_kernel(divide_kernel,
48+
{"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
4549

4650
def launch_triton(a, b):
4751
output0 = torch.zeros_like(a)
@@ -57,15 +61,15 @@ def launch_triton(a, b):
5761
divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)
5862

5963
return (output0, output1, output2, output3, output4)
60-
64+
6165
def launch_torch(a, b):
62-
return (
63-
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
64-
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
65-
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
66-
a / b if float_div is True else torch.zeros_like(a),
67-
a // b if floor is True else torch.zeros_like(a),
68-
)
66+
return (
67+
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
68+
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
69+
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
70+
a / b if float_div is True else torch.zeros_like(a),
71+
a // b if floor is True else torch.zeros_like(a),
72+
)
6973

7074
a = torch.randint(2**32, 2**40, [100, 100], device=device)
7175
b = torch.randint(-10, -1, [100, 100], device=device)
@@ -75,6 +79,6 @@ def launch_torch(a, b):
7579
torch_result = launch_torch(a, b)
7680

7781
for i in range(5):
78-
torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
79-
80-
82+
torch.testing.assert_close(
83+
triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
84+
f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")

0 commit comments

Comments
 (0)