Skip to content

Commit 8b9f5dd

Browse files
authored
Adding a test for swap kernel (#218)
Following up on conversation about copy optimization. --------- Co-authored-by: Renat Idrisov <[email protected]>
1 parent d1c5441 commit 8b9f5dd

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

python/examples/test_swap.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
3+
import triton
4+
import triton.language as tl
5+
6+
# The purpose of this kernel and test is to catch incorrectly optimized kernels
7+
# where copy elimination happens erroneously in the absence of explicit memory allocation.
8+
# Such optimization bugs can result in incorrect behavior when swapping two arrays,
9+
# particularly when both arrays unintentionally end up with the same data due to
10+
# missing intermediate storage or mismanaged memory access.
11+
12+
@triton.jit
def swap_kernel(
    x_ptr,  # *Pointer* to first inout vector.
    y_ptr,  # *Pointer* to second inout vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    """Exchange BLOCK_SIZE contiguous elements between the two vectors.

    Both values are loaded into registers *before* either store, so a buggy
    copy-elimination pass that drops the intermediate would turn the swap
    into a one-way copy — which is exactly what this kernel exists to catch.

    NOTE(review): no mask is passed to tl.load/tl.store, so this assumes the
    total element count is an exact multiple of BLOCK_SIZE (the companion
    test uses size=10240 with BLOCK_SIZE=1024) — confirm before reusing with
    arbitrary sizes.
    """
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Read both blocks first...
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    # ...then write both back crosswise.
    tl.store(x_ptr + offsets, y)
    tl.store(y_ptr + offsets, x)
26+
27+
28+
def swap(x: torch.Tensor, y: torch.Tensor):
    """Swap the contents of two equal-sized 1D tensors in place.

    Launches `swap_kernel` over a 1D grid sized so that every element is
    covered by exactly one program instance.
    """
    total_elements = x.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the input.
        return (triton.cdiv(total_elements, meta["BLOCK_SIZE"]),)

    swap_kernel[grid](x, y, BLOCK_SIZE=1024)
32+
33+
34+
def test(device):
    """Verify that swap() exchanges two random vectors exactly.

    Guards against a miscompiled kernel in which both arrays end up holding
    the same data after the swap.
    """
    torch.manual_seed(0)
    size = 10240  # multiple of BLOCK_SIZE=1024 used by swap()

    a = torch.rand(size, device=device)
    b = torch.rand(size, device=device)
    # Identical inputs would make the swap assertions vacuous.
    assert not torch.equal(a, b)

    a_before, b_before = a.clone(), b.clone()
    swap(a, b)

    assert torch.equal(a, b_before)
    assert torch.equal(b, a_before)

0 commit comments

Comments
 (0)