Skip to content

Commit 9db5c5f

Browse files
authored
[BC Breaking] Make tl.ravel keep element orders by default (#5743)
This doesn't break functional backward compatiblity as the new semantic is a subset of the what was allowed before but it would break performance backward compatiblity. The makes it less error prone.
1 parent a5235d4 commit 9db5c5f

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

python/test/unit/language/test_standard.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,22 @@ def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr):
9898
torch.testing.assert_close(expect, actual)
9999

100100

101+
@pytest.mark.interpreter
102+
def test_ravel(device):
103+
104+
@triton.jit
105+
def triton_ravel(out_ptr):
106+
a = tl.arange(0, 256)
107+
a = tl.reshape(a, (32, 8))
108+
a = tl.ravel(a)
109+
tl.store(out_ptr + tl.arange(0, 256), a)
110+
111+
out = torch.empty((256, ), device=device, dtype=torch.int32)
112+
triton_ravel[(1, )](out)
113+
114+
assert (out == torch.arange(0, 256, device=device)).all()
115+
116+
101117
@pytest.mark.interpreter
102118
@pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]])
103119
def test_swizzle2d(size_i, size_j, size_g, device):

python/triton/language/standard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ def softmax(x, ieee_rounding=False):
5959

6060
@core._tensor_member_fn
6161
@jit
62-
def ravel(x):
62+
def ravel(x, can_reorder=False):
6363
"""
6464
Returns a contiguous flattened view of :code:`x`.
6565
6666
:param x: the input tensor
6767
:type x: Block
6868
"""
69-
return core.reshape(x, [x.numel], can_reorder=True)
69+
return core.reshape(x, [x.numel], can_reorder=can_reorder)
7070

7171

7272
@jit

0 commit comments

Comments
 (0)