
Commit d429e2b

[Gluon] Fix auto layout inconsistencies (#7726)
PR #7646 changed `gl.full` to default to auto layout but not `gl.zeros`. PR #7589 changed `gl.convert_layout` to set the layout on auto layouts, which breaks code such as:

```python
a = ... # some auto layout
gl.set_auto_layout(a, compute_layout)
gl.store(..., gl.convert_layout(a, store_layout))
```
1 parent 1bd811a commit d429e2b
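To make the intended usage concrete, here is a minimal sketch of the pattern after this fix. It is illustrative only: the `gl` alias for `triton.experimental.gluon.language`, the tensor shape, and the layout parameters are assumptions, not part of the change.

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def kernel(compute_layout: gl.constexpr, store_layout: gl.constexpr):
    # With this commit, gl.zeros defaults to AutoLayout(), matching gl.full.
    a = gl.zeros([128], gl.float32)
    # Assigning a concrete layout to an auto-layout value is done explicitly.
    a = gl.set_auto_layout(a, compute_layout)
    # convert_layout is once again a plain layout conversion; it no longer
    # rewrites the layout of auto-layout values behind the caller's back.
    b = gl.convert_layout(a, store_layout)
```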

File tree

- python/test/gluon/test_frontend.py
- python/triton/experimental/gluon/language/_core.py
- python/triton/experimental/gluon/language/_standard.py

3 files changed: +13 -6

python/test/gluon/test_frontend.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -104,6 +104,15 @@ def kernel(src_layout: ttgl.constexpr, dst_layout: ttgl.constexpr):
     assert "layout conversion from BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
     assert "to AutoLayout() is not trivial" in str(e.value.__cause__)
 
+    with pytest.raises(CompilationError) as e:
+        src_layout: ttgl.constexpr = ttgl.AutoLayout()
+        dst_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+        kernel.warmup(src_layout, dst_layout, grid=(1, ))
+
+    assert "layout conversion from AutoLayout()" in str(e.value.__cause__)
+    assert "to BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
+    assert "is not trivial" in str(e.value.__cause__)
+
 
 @gluon.jit
 def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr,
```
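The `kernel` being warmed up lies outside this hunk. As a rough guess at its shape, assuming it forwards the two layouts to `ttgl.convert_layout` with `assert_trivial=True` (which is what produces the "is not trivial" diagnostics checked above); the tensor shape and dtype are made up:

```python
@gluon.jit
def kernel(src_layout: ttgl.constexpr, dst_layout: ttgl.constexpr):
    # Hypothetical body: build a tensor in src_layout and request a trivial
    # conversion to dst_layout, so non-trivial pairs fail at compile time.
    x = ttgl.full([128], 0, ttgl.int32, layout=src_layout)
    ttgl.convert_layout(x, dst_layout, assert_trivial=True)
```

`kernel.warmup(src_layout, dst_layout, grid=(1, ))` then fails compilation for AutoLayout → BlockedLayout just as it already did for BlockedLayout → AutoLayout.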

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -7,7 +7,7 @@
 from triton._C.libtriton.gluon_ir import GluonOpBuilder
 from ._semantic import GluonSemantic
 
-from ._layouts import SharedLayout, DistributedLayout, AutoLayout
+from ._layouts import SharedLayout, DistributedLayout
 from triton._C.libtriton import ir
 import triton.language.core as tl_core
 from triton.language.core import (
@@ -383,8 +383,6 @@ def convert_layout(value, layout, assert_trivial=False, _semantic=None):
         tensor: The tensor with the new layout.
     """
     layout = _unwrap_if_constexpr(layout)
-    if isinstance(value.type.layout, AutoLayout):
-        return set_auto_layout(value, layout, _semantic=_semantic)
     return _semantic.convert_layout(value, layout, assert_trivial)
 
 
@@ -397,7 +395,7 @@ def full(shape, value, dtype, layout=None, _semantic=None):
         shape (Sequence[int]): The shape of the tensor.
         value (int or float): The fill value.
         dtype (dtype): The data type for the tensor.
-        layout (DistributedLayout): The layout of the output tensor.
+        layout (Optional[DistributedLayout]): The layout of the output tensor, defaults to AutoLayout().
 
     Returns:
         tensor: A tensor where every element equals value.
```
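With the docstring change above, `layout` is optional on `gl.full`. A short illustration inside a kernel; the shape, dtype, and the `blocked` layout value are assumptions:

```python
# 'blocked' stands for some DistributedLayout constexpr (assumed).
x = gl.full([64], 1.0, gl.float32)                  # layout defaults to AutoLayout()
y = gl.full([64], 1.0, gl.float32, layout=blocked)  # explicit DistributedLayout
```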

python/triton/experimental/gluon/language/_standard.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -32,14 +32,14 @@
 
 
 @jit
-def zeros(shape, dtype, layout):
+def zeros(shape, dtype, layout=None):
     """
     Create a tensor filled with zeros.
 
     Args:
         shape (Sequence[int]): The shape of the tensor.
         dtype (dtype): The data type for the tensor.
-        layout (DistributedLayout): The distributed layout of the tensor.
+        layout (Optional[DistributedLayout]): The distributed layout of the tensor, defaults to AutoLayout().
 
     Returns:
         tensor: A tensor where every element is zero.
```
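This brings `zeros` in line with `full`: both now accept an optional layout. A short illustration inside a kernel; the shape, dtype, and the `blocked` layout value are assumptions:

```python
# 'blocked' stands for some DistributedLayout constexpr (assumed).
a = gl.zeros([64], gl.float32)                  # AutoLayout(), resolved by layout inference
b = gl.zeros([64], gl.float32, layout=blocked)  # pinned to a concrete layout
```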
