from __future__ import annotations

from typing import Callable

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch._inductor.runtime import triton_helpers

from helion.runtime import default_launcher as _default_launcher

DEVICE = 'xpu'  # PyTorch device string for Intel GPUs


@triton.jit
def _helion_matmul(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr,
                   _BLOCK_SIZE_2: tl.constexpr):
    # Grouped ("swizzled") program ordering: output tiles are visited in groups
    # of 64 block-rows to improve L2 cache reuse, as in the standard Triton
    # matmul schedule. Sizes are hardcoded for the 1024x1024 specialization.
    num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
    num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
    inner_2d_pid = tl.program_id(0)
    num_pid_in_group = 64 * num_pid_n
    group_id = inner_2d_pid // num_pid_in_group
    first_pid_m = group_id * 64
    group_size_m = min(num_pid_m - first_pid_m, 64)
    pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
    pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
    offset_0 = pid_0 * _BLOCK_SIZE_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    # Accumulate the dot products in float32 for numerical stability.
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    # March along the shared K dimension in steps of _BLOCK_SIZE_2, loading
    # zero-padded tiles of x and y via boundary-checked block pointers.
    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
        acc_copy = acc
        load = tl.load(
            tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]),
            boundary_check=[0, 1], padding_option='zero')
        load_1 = tl.load(
            tl.make_block_ptr(y, [1024, 1024], [1024, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]),
            boundary_check=[0, 1], padding_option='zero')
        acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy, input_precision='tf32',
                     out_dtype=tl.float32)
    # Fused epilogue: load the broadcast bias row, add it to the accumulator,
    # apply ReLU (elementwise max with 0), and cast back to float16 on store.
    load_2 = tl.load(
        tl.make_block_ptr(epilogue_closure_0, [1, 1024], [1024, 1], [0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]))
    v_0 = tl.cast(load_2, tl.float32)
    v_1 = acc + v_0
    v_2 = tl.full([], 0, tl.int32)
    v_3 = triton_helpers.maximum(v_2, v_1)
    v_4 = tl.cast(v_3, tl.float16)
    tl.store(
        tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]),
        v_4, boundary_check=[0, 1])
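

# For testing, a minimal eager-PyTorch sketch of what the fused kernel computes
# (an equivalence inferred from the kernel body above, not part of the Helion
# output): float32-accumulated matmul, bias add, ReLU, then a cast to float16.
def _eager_reference(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
    # Mirror the kernel's float32 accumulator before the cast on store.
    acc = x.float() @ y.float() + bias.float()
    return torch.relu(acc).to(torch.promote_types(x.dtype, y.dtype))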


def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc, *,
           _launcher=_default_launcher):
    """
    Performs matrix multiplication of x and y with an optional epilogue function.

    Args:
        x (Tensor): Left matrix of shape [m, k].
        y (Tensor): Right matrix of shape [k, n].
        epilogue (Callable, optional): Function applied to the accumulator and the
            tile indices after the matmul. Defaults to identity (no change).

    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f'size mismatch {k} != {k2}'
    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
    _BLOCK_SIZE_0 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_2 = 16
    # The kernel is specialized for 1024x1024 operands, and the epilogue body
    # was fused into it at compile time; at launch only the tensor captured in
    # the epilogue's closure (the bias) is passed through.
    _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y,
              epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2,
              num_stages=4)
    return out
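
# A minimal usage sketch (hedged; `example_bias` is illustrative, not part of
# the generated output). The launcher unpacks `epilogue.__closure__[0]`, so the
# epilogue must capture exactly one tensor; the default identity lambda, which
# has no closure, would not work with this specialized launcher:
#
#     example_bias = torch.zeros([1, 1024], device=DEVICE, dtype=torch.float16)
#     result = matmul(x, y, lambda acc, tile: torch.relu(acc + example_bias[tile]))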


bias = torch.ones([1, 1024], device=DEVICE, dtype=torch.float16)
args = (
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    torch.ones([1024, 1024], device=DEVICE, dtype=torch.float16),
    lambda acc, tile: torch.relu(acc + bias[tile]),
)

# In-place fills after the tensors are captured: the kernel reads the live
# tensors at launch, so these are the values the computation actually sees.
bias.fill_(0.7)
args[0].fill_(0.1)
args[1].fill_(0.2)


def make_epilogue(bias):

    def epilogue(acc, tile):
        # Keep this body in sync with the epilogue fused into _helion_matmul
        # (bias add followed by ReLU); only the captured bias tensor is passed
        # to the compiled kernel at call time.
        return torch.relu(acc + bias[tile[0], tile[1]])

    return epilogue


epilogue = make_epilogue(bias)

out = matmul(args[0], args[1], epilogue)
torch.xpu.synchronize()
torch.testing.assert_close(out, torch.relu(args[0] @ args[1] + bias))
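
# Sanity check of the expected value (a hedged aside, not part of the generated
# test): every output element is relu(1024 * 0.1 * 0.2 + 0.7) = relu(21.18),
# i.e. 21.18 before float16 rounding, so the assertion above compares a
# uniform, easy-to-inspect result.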