Skip to content

Commit 16a87b4

Browse files
authored
[Gluon] Use _convert_elem_to_ir_value in some APIs (#7022)
This allows passing constexprs as the arguments, like `tmem.subslice(0)`
1 parent 5b50c7f commit 16a87b4

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

python/test/gluon/test_frontend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,29 @@ def kernel():
541541
assert "order must be a permutation of 0..(rank-1), but was [1]" in str(e.value.__cause__)
542542

543543

544+
@gluon.jit
545+
def tmem_subslice_kernel():
546+
layout: ttgl.constexpr = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
547+
tmem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, [2, 256, 256], layout)
548+
tmem.subslice(0)
549+
550+
551+
def test_tmem_subslice_constexpr():
552+
expecttest.assert_expected_inline(
553+
run_parser(tmem_subslice_kernel).str_nodebug(), """\
554+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
555+
module {
556+
tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
557+
%result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
558+
%c0_i32 = arith.constant 0 : i32
559+
%c0_i32_0 = arith.constant 0 : i32
560+
%0 = ttg.memdesc_subview %result[%c0_i32, %c0_i32_0, %c0_i32_0] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256>
561+
tt.return
562+
}
563+
}
564+
""")
565+
566+
544567
@gluon.jit
545568
def smem_and_layout_user(smem, a: ttgl.constexpr):
546569
pass

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def memdesc_slice(mem_desc, index, shape, layout, builder: GluonOpBuilder):
7171
assert mem_desc.rank > len(shape), f"source rank ({mem_desc.rank}) must be greater than result rank ({len(shape)})"
7272

7373
offsets = [builder.get_int32(0)] * mem_desc.rank
74-
offsets[0] = index.handle
74+
offsets[0] = tl_semantic._convert_elem_to_ir_value(builder, index, require_i64=False)
7575
return _memdesc_subview(mem_desc, offsets, shape, layout, builder)
7676

7777

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Tuple, List, TYPE_CHECKING
33

44
from dataclasses import dataclass
5+
from triton.language.semantic import _convert_elem_to_ir_value, _convert_to_ir_values
56
from triton.experimental.gluon.language import _core as ttgl
67
from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
78

@@ -140,12 +141,12 @@ def subslice(self, index, shape=None, layout=None, _builder: GluonOpBuilder = No
140141
if shape is None:
141142
shape = self.shape[1:]
142143

143-
index = _unwrap_if_constexpr(index)
144+
index = _convert_elem_to_ir_value(_builder, index, require_i64=False)
144145
shape = [_unwrap_if_constexpr(s) for s in shape]
145146
layout = _unwrap_if_constexpr(layout)
146147

147148
offsets = [_builder.get_int32(0)] * self.rank
148-
offsets[0] = index.handle
149+
offsets[0] = index
149150
ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
150151
ret.handle = _builder.create_memdesc_subview(ret.type.to_ir(_builder), self.handle, offsets)
151152
return ret
@@ -178,6 +179,6 @@ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_
178179
true = ttgl.to_tensor(True, _builder=_builder)
179180
mbarrier_preds = [true] * len(mbarriers)
180181
else:
181-
mbarrier_preds = [pred.handle for pred in mbarrier_preds]
182+
mbarrier_preds = _convert_to_ir_values(_builder, mbarrier_preds, require_i64=False)
182183

183184
_builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers, mbarrier_preds)

python/triton/language/semantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,8 @@ def _convert_elem_to_ir_value(builder, elem, require_i64):
18651865
if isinstance(elem, int):
18661866
elem = tl.constexpr(elem)
18671867
if isinstance(elem, tl.constexpr):
1868+
if isinstance(elem.value, bool):
1869+
return builder.get_int1(elem.value)
18681870
if require_i64:
18691871
assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
18701872
f"got a value {elem.value} which is out of the range"

0 commit comments

Comments
 (0)