|
| 1 | +import torch |
| 2 | + |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | +# The purpose of this kernel and test is to catch incorrectly optimized kernels |
| 7 | +# where copy elimination happens erroneously in the absence of explicit memory allocation. |
| 8 | +# Such optimization bugs can result in incorrect behavior when swapping two arrays, |
| 9 | +# particularly when both arrays unintentionally end up with the same data due to |
| 10 | +# missing intermediate storage or mismanaged memory access. |
| 11 | + |
@triton.jit
def swap_kernel(
    x_ptr,  # *Pointer* to first inout vector.
    y_ptr,  # *Pointer* to second inout vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Exchange one BLOCK_SIZE-sized chunk between the two vectors in place.
    #
    # NOTE(review): loads and stores are deliberately UNMASKED — every program
    # touches a full BLOCK_SIZE range, so the total element count must be an
    # exact multiple of BLOCK_SIZE or the kernel accesses memory out of bounds.
    # (The accompanying test uses size 10240 = 10 * 1024, which satisfies this.)
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Both blocks are loaded into registers BEFORE either store: per the file
    # header, this in-register intermediate copy is exactly what an erroneous
    # copy-elimination pass might remove, which would make x and y end up
    # holding the same data.
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(x_ptr + offsets, y)
    tl.store(y_ptr + offsets, x)
| 27 | + |
def swap(x: torch.Tensor, y: torch.Tensor):
    """Exchange the contents of ``x`` and ``y`` in place via ``swap_kernel``.

    ``swap_kernel`` performs unmasked loads and stores, so each launched
    program touches a full block of elements. To avoid silent out-of-bounds
    accesses we validate the sizes here instead of letting the kernel run
    off the end of the buffers.

    Args:
        x: First tensor; receives ``y``'s original contents.
        y: Second tensor; receives ``x``'s original contents.

    Raises:
        ValueError: If ``x`` and ``y`` have different element counts, or the
            element count is not a multiple of the block size.
    """
    block_size = 1024
    n_elements = x.numel()
    if y.numel() != n_elements:
        raise ValueError(
            f"x and y must have the same number of elements, "
            f"got {n_elements} and {y.numel()}"
        )
    if n_elements % block_size != 0:
        raise ValueError(
            f"number of elements ({n_elements}) must be a multiple of the "
            f"block size ({block_size}): swap_kernel uses unmasked accesses"
        )
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    swap_kernel[grid](x, y, BLOCK_SIZE=block_size)
| 33 | + |
def test(device):
    """Round-trip check: after swap(), each tensor holds the other's old data."""
    torch.manual_seed(0)
    n = 10240
    a = torch.rand(n, device=device)
    b = torch.rand(n, device=device)
    # Two independent random draws must differ, otherwise the swap check
    # below would be vacuous.
    assert not torch.equal(a, b)
    a_before, b_before = a.clone(), b.clone()
    swap(a, b)
    assert torch.equal(a, b_before)
    assert torch.equal(b, a_before)