
Commit 8d99aa1

[FRONTEND] Allow arbitrary dim in tl.flip (#6853)
Currently, `tl.flip` only supports flipping along the last dimension. We can still use the same algorithm to flip any dimension as long as we tweak where the swaps begin/end.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 2509898 commit 8d99aa1
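What the change enables, in kernel terms: `tl.flip` can now reverse a block along any axis, not just the last one. Below is a minimal sketch of the new capability; the kernel name and shapes are illustrative, a CUDA device is assumed, and the flipped dimension must still be a power of two.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def flip_rows_kernel(X, Z, M: tl.constexpr, N: tl.constexpr):
    # Load an (M, N) block, flip it along dim 0 (previously rejected by the
    # assert in _get_flip_dim), and store the result.
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    x = tl.load(X + offs)
    tl.store(Z + offs, tl.flip(x, 0))


x = torch.arange(32 * 16, device="cuda", dtype=torch.float32).reshape(32, 16)
z = torch.empty_like(x)
flip_rows_kernel[(1, )](x, z, 32, 16)
assert torch.equal(z, torch.flip(x, (0, )))
```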

File tree

2 files changed: +26 −24 lines changed


python/test/unit/language/test_standard.py

Lines changed: 15 additions & 13 deletions
```diff
@@ -65,24 +65,26 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
+@pytest.mark.parametrize("M, N, K", [[1, 16, 64], [8, 2, 256], [32, 1, 2], [128, 8, 1]])
 @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
-def test_flip(M, N, dtype_str, device):
+@pytest.mark.parametrize("dim", [0, 1, 2, -2])
+def test_flip(M, N, K, dtype_str, dim, device):
 
     @triton.jit
-    def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr):
-        offx = tl.arange(0, M)
-        offy = tl.arange(0, N) * M
-        off2d = offx[None, :] + offy[:, None]
-        x = tl.load(X + off2d)
-        x = tl.flip(x)
-        tl.store(Z + off2d, x)
-
-    x = numpy_random((N, M), dtype_str=dtype_str)
+    def flip_kernel(X, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, dim: tl.constexpr):
+        offx = tl.arange(0, M) * N * K
+        offy = tl.arange(0, N) * K
+        offz = tl.arange(0, K)
+        off3d = offx[:, None, None] + offy[None, :, None] + offz[None, None, :]
+        x = tl.load(X + off3d)
+        x = tl.flip(x, dim)
+        tl.store(Z + off3d, x)
+
+    x = numpy_random((M, N, K), dtype_str=dtype_str)
     x = torch.from_numpy(x).to(device)
-    y = torch.flip(x, (1, ))
+    y = torch.flip(x, (dim, ))
     z = torch.empty_like(x, device=device)
-    flip_kernel[(1, )](x, z, N, M, num_warps=8)
+    flip_kernel[(1, )](x, z, M, N, K, dim, num_warps=8)
     assert (y == z).all(), (y, z)
 
 
```
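As a side note on the updated test kernel: `offx`, `offy`, and `offz` broadcast into the row-major linear offsets of a contiguous `(M, N, K)` block, so `off3d` addresses every element exactly once. A quick NumPy check of that arithmetic (sizes picked arbitrarily for illustration):

```python
import numpy as np

M, N, K = 8, 2, 4
# Strides of a contiguous (M, N, K) array are (N*K, K, 1), so broadcasting
# the three 1-D ranges enumerates every row-major linear index once.
offx = np.arange(M) * N * K
offy = np.arange(N) * K
offz = np.arange(K)
off3d = offx[:, None, None] + offy[None, :, None] + offz[None, None, :]
assert np.array_equal(off3d, np.arange(M * N * K).reshape(M, N, K))
```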

python/triton/language/standard.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -475,7 +475,8 @@ def _get_flip_dim(dim, shape):
     shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-    assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
     return core.constexpr(dim)
 
 
@@ -487,20 +488,19 @@ def flip(x, dim=None):
 
     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along (currently only final dimension supported)
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
-    core.static_assert(_is_power_of_two(x.numel))
-    # reshape the tensor to have all dimensions be 2.
-    # TODO: We shouldn't have to change the dimensions not sorted.
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
 
+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-    for i in core.static_range(start, steps):
-        y = y ^ xor_sum(y, i, True)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + core.tuple([2] * steps) + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
 
```
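For intuition about the updated implementation: `_get_flip_dim` now normalizes a negative `dim` (e.g. `dim = -2` on a 3-D shape becomes `1`) so the swap axes `_dim + i` are valid, and `flip` reshapes only the flipped axis, of length `2**steps`, into `steps` axes of size 2. Swapping the two halves along each of those axes complements every bit of the index along the original axis, which is exactly a reversal; the kernel performs each length-2 swap with an xor trick on the integer bit pattern. A rough NumPy emulation of the same scheme (the function name is illustrative, and `np.flip` stands in for the xor-based swap):

```python
import numpy as np


def flip_via_bit_axes(x, dim):
    # Split the flipped axis of length 2**steps into `steps` axes of size 2,
    # then swap the two halves along each of those axes; composing all the
    # swaps reverses the original axis.
    dim = dim % x.ndim                      # normalize negative dims
    steps = x.shape[dim].bit_length() - 1   # log2 of the flipped axis
    assert 1 << steps == x.shape[dim], "flipped axis must be a power of two"
    y = x.reshape(x.shape[:dim] + (2,) * steps + x.shape[dim + 1:])
    for i in range(steps):
        y = np.flip(y, axis=dim + i)        # one length-2 swap per axis
    return y.reshape(x.shape)


x = np.arange(2 * 8 * 3).reshape(2, 8, 3)
assert np.array_equal(flip_via_bit_axes(x, 1), np.flip(x, 1))
assert np.array_equal(flip_via_bit_axes(x, -2), np.flip(x, -2))
```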
