
Commit 7d3bf12

[Gluon] Broadcast auto with concrete layouts (#7491)
This also makes `arange` return auto layout by default, so you can, for example, write:

```python
xidx = gl.arange(0, 32)[:, None]
yidx = gl.arange(0, 16)[None, :]
off = xidx * x_stride + yidx * y_stride
off = gl.convert_layout(off, concrete_layout)
```

I also fixed `filecheck_test` to disable the overflow sanitizer, since it can cause lit tests to fail when CHECK statements match against the overflow-sanitizer ops.
Parent commit: 041ec1b
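The broadcast change in the title is easiest to see in a small kernel. A minimal sketch, mirroring the test added below (the import paths and the `BlockedLayout` parameters are illustrative):

```python
import triton.experimental.gluon.language as ttgl
from triton.experimental import gluon


@gluon.jit
def auto_broadcast_kernel():
    # x stays in AutoLayout; y already carries a concrete blocked layout.
    x = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.AutoLayout())
    y = ttgl.full([1, 16], 2, ttgl.int32,
                  layout=ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]))
    # x is converted to y's layout, then both sides are broadcast to 16x16.
    _ = x + y
```

Previously, mixing an `AutoLayout` operand with a concretely laid-out one raised a layout-mismatch error, so the auto side had to be converted by hand.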

4 files changed: +46, -8 lines

python/test/gluon/test_frontend.py (25 additions, 0 deletions)

```diff
@@ -1233,3 +1233,28 @@ def test_auto_layout():
     z = x + y
     # CHECK: (tensor<16x8xi32, #gluon.auto_encoding>) -> tensor<16xi32, #gluon.auto_encoding
     ttgl.sum(z, axis=1)
+
+    # CHECK: tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #gluon.auto_encoding>
+    ttgl.arange(0, 32)
+
+
+@filecheck_test
+@gluon.jit
+def test_auto_layout_broadcast():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked
+    # CHECK: [[X:%.*]] = arith.constant dense<1> : tensor<16x1xi32, #gluon.auto_encoding>
+    # CHECK: [[Y:%.*]] = arith.constant dense<2> : tensor<1x16xi32, [[BLOCKED]]>
+    x = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    y = ttgl.full([1, 16], 2, ttgl.int32, layout=ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]))
+
+    # CHECK: [[XCVT:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[XBCAST:%.*]] = tt.broadcast [[XCVT]]
+    # CHECK: [[YBCAST:%.*]] = tt.broadcast [[Y]]
+    # CHECK: arith.addi [[XBCAST]], [[YBCAST]] : tensor<16x16xi32, [[BLOCKED]]>
+    _ = x + y
+
+    # CHECK: [[XCVT2:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[YBCAST2:%.*]] = tt.broadcast [[Y]]
+    # CHECK: [[XBCAST2:%.*]] = tt.broadcast [[XCVT2]]
+    # CHECK: arith.muli [[YBCAST2]], [[XBCAST2]] : tensor<16x16xi32, [[BLOCKED]]>
+    _ = y * x
```
python/triton/_filecheck.py (7 additions, 4 deletions)

```diff
@@ -42,8 +42,9 @@ def run_filecheck(name, module_str, check_template):
         temp.write(check_template)
 
     try:
-        subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module],
-                                stderr=subprocess.STDOUT)
+        subprocess.check_output(
+            [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+            stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as error:
        decoded = error.output.decode('unicode_escape')
        raise ValueError(decoded)
@@ -60,8 +61,10 @@ def run_parser(kernel_fn):
    ir.load_dialects(context)
    stub_backend.load_dialects(context)
 
-    extra_options = src.parse_options()
-    options = stub_backend.parse_options(dict(**extra_options))
+    options = dict(sanitize_overflow=False)
+    options.update(src.parse_options())
+
+    options = stub_backend.parse_options(options)
    codegen_fns = stub_backend.get_codegen_implementation(options)
    module_map = stub_backend.get_module_map()
    module = src.make_ir(options, codegen_fns, module_map, context)
```
python/triton/experimental/gluon/language/_core.py (3 additions, 2 deletions)

```diff
@@ -43,6 +43,7 @@
 )
 
 _IMPORT_FROM_TRITON: List[str] = [
+    "broadcast",
     "expand_dims",
     "inline_asm_elementwise",
     "join",
@@ -341,14 +342,14 @@ def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
 
 
 @builtin
-def arange(start, end, layout, _semantic=None):
+def arange(start, end, layout=None, _semantic=None):
     """
     Generate a sequence tensor with values in [start, end) using a specified layout.
 
     Args:
         start (int): Inclusive start of the sequence.
         end (int): Exclusive end of the sequence.
-        layout (DistributedLayout): The layout of the output tensor.
+        layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout.
 
     Returns:
         tensor: A 1D tensor containing sequential values.
```
python/triton/experimental/gluon/language/_semantic.py (11 additions, 2 deletions)

```diff
@@ -112,7 +112,14 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
         lhs_shape = lhs_ty.get_block_shapes()
         rhs_shape = rhs_ty.get_block_shapes()
         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
-        if lhs_ty.layout != rhs_ty.layout:
+
+        is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
+        is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
+        if is_lhs_auto and not is_rhs_auto:
+            lhs = self.convert_layout(lhs, rhs_ty.layout)
+        elif is_rhs_auto and not is_lhs_auto:
+            rhs = self.convert_layout(rhs, lhs_ty.layout)
+        elif lhs_ty.layout != rhs_ty.layout:
             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")
 
         lhs = self.broadcast_impl_shape(lhs, ret_shape)
```
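Restated, the rule this hunk adds is: if exactly one operand is in `AutoLayout`, it adopts the other operand's layout; otherwise the two layouts must already match (both auto, or both the same concrete layout). A standalone sketch of that decision, using a plain-Python stand-in rather than the actual Gluon layout classes:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AutoLayout:
    """Stand-in for Gluon's AutoLayout marker, for illustration only."""


def resolve_broadcast_layout(lhs_layout, rhs_layout):
    """Return the layout both operands should use before broadcasting."""
    lhs_auto = isinstance(lhs_layout, AutoLayout)
    rhs_auto = isinstance(rhs_layout, AutoLayout)
    if lhs_auto and not rhs_auto:
        return rhs_layout  # the auto side adopts the concrete layout
    if rhs_auto and not lhs_auto:
        return lhs_layout
    if lhs_layout != rhs_layout:
        raise ValueError(f"Layout mismatch in broadcast: {lhs_layout} vs {rhs_layout}")
    return lhs_layout  # both auto, or both the same concrete layout
```

The `assert_trivial=False` default added to `convert_layout` in the last hunk below is what the new calls in this hunk rely on: `broadcast_impl_value` can pass just the value and the target layout.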
```diff
@@ -121,6 +128,8 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
 
     def arange(self, start, end, layout):
         shape = [end - start]
+        if layout is None:
+            layout = AutoLayout()
         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
         return super().arange(start, end, ret_ty=ret_ty)
 
@@ -138,7 +147,7 @@ def full(self, shape, value, dtype, layout):
         scalar = self.make_scalar(value, dtype)
         return self.splat(scalar, shape, layout)
 
-    def convert_layout(self, value, layout, assert_trivial):
+    def convert_layout(self, value, layout, assert_trivial=False):
         ty = value.type
         _check(isinstance(ty, ttgl.distributed_type),
                lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
```
