# flake8: noqa: F821, F841
import torch
import pytest

import triton
import triton.language as tl

aten = torch.ops.aten


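# Re-parse a @triton.jit kernel from its source, applying textual substitutions
# so placeholder tokens in the template can be replaced with generated code.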
def patch_kernel(template, to_replace):
    kernel = triton.JITFunction(template.fn)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel


@pytest.mark.parametrize("float_div", [True, False])
@pytest.mark.parametrize("floor", [True, False])
@pytest.mark.parametrize("trunc", [True, False])
def test_divide(float_div, floor, trunc, device):
    # regression test for various division cases

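    # Kernel template: computes float, floor, and trunc division of a by b.
    # GENERATE_OUTPUTS_HERE is replaced below with the tl.store calls for
    # whichever rounding modes this parametrization exercises.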
    @triton.jit
    def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(a + (x0), xmask)
        tmp2 = tl.load(b + (x0), xmask)
        # custom bits
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 / tmp3  # true (float) division
        # floor division derived from the integer quotient: subtract 1 when the
        # operands' signs differ and the division leaves a remainder
        tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
                        tmp0 // tmp2)
        tmp6 = tmp0 // tmp2  # integer quotient, checked against rounding_mode="trunc" below
        GENERATE_OUTPUTS_HERE

    torch.manual_seed(0)

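    # tl.store statements to splice into the kernel for each rounding mode under
    # test; the four-space indents keep the patched kernel source valid Python.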
    outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n    tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else ""
    outputs_floor = "    tl.store(out_ptr1 + (x0), tmp5, xmask)\n    tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else ""
    outputs_trunc = "    tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""

    divide_kernel = patch_kernel(divide_kernel,
                                 {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})

    def launch_triton(a, b):
        # Zero-initialize all five outputs; the patched kernel only writes the
        # ones whose stores were spliced in above.
        output0 = torch.zeros_like(a)
        output1 = torch.zeros_like(a)
        output2 = torch.zeros_like(a)
        output3 = torch.zeros_like(a)
        output4 = torch.zeros_like(a)

        n_elements = output0.numel()

        grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )

        divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)

        return (output0, output1, output2, output3, output4)

    def launch_torch(a, b):
        # Reference results from PyTorch; slots for modes not under test stay
        # zero-filled so they trivially match the zero-initialized Triton outputs.
        return (
            aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
            aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
            aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
            a / b if float_div is True else torch.zeros_like(a),
            a // b if floor is True else torch.zeros_like(a),
        )

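    # Large positive numerators divided by small negative denominators exercise
    # the sign handling where floor and trunc quotients differ.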
    a = torch.randint(2**32, 2**40, [100, 100], device=device)
    b = torch.randint(-10, -1, [100, 100], device=device)

    for iteration in range(100):
        triton_result = launch_triton(a, b)
        torch_result = launch_torch(a, b)

        for i in range(5):
            torch.testing.assert_close(
                triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
                f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iteration}, {i} failed\n{msg}")