Skip to content

Commit 2f81756

Browse files
authored
[Gluon] Rename memdesc subview ops (#7165)
To be more consistent with PyTorch, rename:
- subslice -> index
- split -> slice

Technically, the inputs to a slice should be start and end instead of start and length. However, the length must be known statically in order to create the types, so unfortunately this isn't possible.
1 parent c80eef1 commit 2f81756

File tree

4 files changed

+50
-63
lines changed

4 files changed

+50
-63
lines changed

python/test/gluon/test_frontend.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -109,12 +109,12 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
109109
mem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout, a)
110110
b = mem.load(layout) # noqa: F841
111111
mem.store(a)
112-
slice1 = mem.split(0, YBLOCK // 2) # noqa: F841
113-
slice2 = mem.split(YBLOCK // 2, YBLOCK // 2) # noqa: F841
112+
slice1 = mem.slice(0, YBLOCK // 2) # noqa: F841
113+
slice2 = mem.slice(YBLOCK // 2, YBLOCK // 2) # noqa: F841
114114

115115
buffers = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.float32, [2, XBLOCK, YBLOCK], tmem_layout)
116116
for i in range(2):
117-
buffers.subslice(i).load(layout)
117+
buffers.index(i).load(layout)
118118

119119

120120
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
@@ -165,9 +165,9 @@ def test_tensor_memory(fresh_knobs):
165165
def shared_memory_subview_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
166166
XHALF: ttgl.constexpr = XBLOCK // 2
167167
smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, XBLOCK], smem_layout)
168-
view = smem.split(XHALF, XHALF, dim=1)
168+
view = smem.slice(XHALF, XHALF, dim=1)
169169
value = view.load(layout)
170-
view = smem.split(XHALF, XHALF, dim=0)
170+
view = smem.slice(XHALF, XHALF, dim=0)
171171
view.store(value.trans())
172172

173173

@@ -203,25 +203,25 @@ def test_shared_memory_subview(fresh_knobs):
203203

204204

205205
@gluon.jit
206-
def shared_memory_subslice_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
206+
def shared_memory_index_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
207207
smem = ttgl.allocate_shared_memory(ttgl.int32, [4, XBLOCK], smem_layout)
208208
for i in range(4):
209-
smem.subslice(i).load(layout)
209+
smem.index(i).load(layout)
210210

211211

212-
def test_shared_memory_subslice(fresh_knobs):
212+
def test_shared_memory_index(fresh_knobs):
213213
knobs.compilation.disable_line_info = True
214214

215215
layout = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0])
216216
smem_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
217-
h = shared_memory_subslice_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
217+
h = shared_memory_index_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
218218
expecttest.assert_expected_inline(
219219
anonymize_ir(h.asm["source"]), """\
220220
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
221221
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
222222
#smem = #ttg.shared_memory
223223
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
224-
tt.func public @shared_memory_subslice_kernel() attributes {noinline = false} {
224+
tt.func public @shared_memory_index_kernel() attributes {noinline = false} {
225225
%0 = ttg.local_alloc : () -> !ttg.memdesc<4x256xi32, #shared, #smem, mutable> loc(#loc)
226226
%c0_i32 = arith.constant 0 : i32 loc(#loc)
227227
%c4_i32 = arith.constant 4 : i32 loc(#loc)
@@ -250,7 +250,7 @@ def shared_memory_cast_kernel():
250250
rank=2, ctas_per_cga=[1, 1], cta_split_num=[1,
251251
1], cta_order=[1, 0])
252252
smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
253-
perm = smem.subslice(0).permute((1, 0))
253+
perm = smem.index(0).permute((1, 0))
254254
ttgl.static_assert(perm.type.layout == layout_T)
255255

256256
layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
@@ -562,18 +562,18 @@ def kernel():
562562

563563

564564
@gluon.jit
565-
def tmem_subslice_kernel():
565+
def tmem_index_kernel():
566566
layout: ttgl.constexpr = TensorMemoryLayout(block=[128, 128], unpacked=True)
567567
tmem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, [2, 256, 256], layout)
568-
tmem.subslice(0)
568+
tmem.index(0)
569569

570570

571-
def test_tmem_subslice_constexpr():
571+
def test_tmem_index_constexpr():
572572
expecttest.assert_expected_inline(
573-
anonymize_ir(run_parser(tmem_subslice_kernel).str_nodebug()), """\
573+
anonymize_ir(run_parser(tmem_index_kernel).str_nodebug()), """\
574574
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
575575
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
576-
tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
576+
tt.func public @tmem_index_kernel() attributes {noinline = false} {
577577
%result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
578578
%c0_i32 = arith.constant 0 : i32
579579
%c0_i32_0 = arith.constant 0 : i32

python/triton/experimental/gluon/language/_core.py

Lines changed: 10 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,10 @@ def shape(self):
206206
def rank(self):
207207
return len(self.shape)
208208

209+
@property
210+
def layout(self):
211+
return self.type.layout
212+
209213
def __str__(self) -> str:
210214
return str(self.type)
211215

@@ -219,31 +223,16 @@ def store(self, value, _semantic: GluonSemantic) -> None:
219223
return _semantic.shared_store(self, value)
220224

221225
@builtin
222-
def split(self, offset, size, dim=None, layout=None, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
223-
if layout is None:
224-
layout = self.type.layout
225-
if dim is None:
226-
dim = 0
227-
228-
offset = _unwrap_if_constexpr(offset)
229-
size = _unwrap_if_constexpr(size)
226+
def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
227+
start = _unwrap_if_constexpr(start)
228+
length = _unwrap_if_constexpr(length)
230229
dim = _unwrap_if_constexpr(dim)
231-
layout = _unwrap_if_constexpr(layout)
232-
233-
return _semantic.memdesc_split(self, offset, size, dim, layout)
230+
return _semantic.memdesc_slice(self, start, length, dim)
234231

235232
@builtin
236-
def subslice(self, index, shape=None, layout=None, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
237-
if layout is None:
238-
layout = self.type.layout
239-
if shape is None:
240-
shape = self.shape[1:]
241-
233+
def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
242234
index = _unwrap_if_constexpr(index)
243-
shape = [_unwrap_if_constexpr(s) for s in shape]
244-
layout = _unwrap_if_constexpr(layout)
245-
246-
return _semantic.memdesc_slice(self, index, shape, layout)
235+
return _semantic.memdesc_index(self, index)
247236

248237
@builtin
249238
def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -160,26 +160,25 @@ def shared_store(self, mem_desc, value):
160160
def shared_dealloc(self, mem_desc):
161161
self.builder.create_local_dealloc(mem_desc.handle)
162162

163-
def _memdesc_subview(self, mem_desc, offsets, shape, layout):
163+
def _memdesc_subview(self, mem_desc, offsets, shape):
164+
layout = mem_desc.layout
164165
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
165166
builder = self.builder
166167
handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets)
167168
return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
168169

169-
def memdesc_split(self, mem_desc, offset, size, dim, layout):
170+
def memdesc_slice(self, mem_desc, start, length, dim):
170171
offsets = [self.builder.get_int32(0)] * mem_desc.rank
171-
offsets[dim] = self.builder.get_int32(offset)
172+
offsets[dim] = self.to_tensor(start).handle
172173
shape = list(mem_desc.shape)
173-
shape[dim] = size
174-
return self._memdesc_subview(mem_desc, offsets, shape, layout)
175-
176-
def memdesc_slice(self, mem_desc, index, shape, layout):
177-
assert mem_desc.rank > len(
178-
shape), f"source rank ({mem_desc.rank}) must be greater than result rank ({len(shape)})"
174+
shape[dim] = length
175+
return self._memdesc_subview(mem_desc, offsets, shape)
179176

177+
def memdesc_index(self, mem_desc, index):
178+
shape = mem_desc.shape[1:]
180179
offsets = [self.builder.get_int32(0)] * mem_desc.rank
181-
offsets[0] = self._convert_elem_to_ir_value(index, require_i64=False)
182-
return self._memdesc_subview(mem_desc, offsets, shape, layout)
180+
offsets[0] = self.to_tensor(index).handle
181+
return self._memdesc_subview(mem_desc, offsets, shape)
183182

184183
def memdesc_trans(self, mem_desc, order):
185184
assert len(order) == len(

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass
55
from triton.experimental.gluon.language import _core as ttgl
66
from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
7+
from triton.experimental.gluon.language._semantic import _check
78

89
from . import tma
910
from ..hopper import mbarrier, fence_async_shared
@@ -108,6 +109,10 @@ def shape(self):
108109
def rank(self):
109110
return len(self.shape)
110111

112+
@property
113+
def layout(self):
114+
return self.type.layout
115+
111116
def __str__(self) -> str:
112117
return str(self.type)
113118

@@ -126,12 +131,12 @@ def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
126131
_semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
127132

128133
@builtin
129-
def split(self, start, length, _semantic: GluonSemantic) -> None:
134+
def slice(self, start, length, _semantic: GluonSemantic) -> None:
130135
start = _unwrap_if_constexpr(start)
131136
length = _unwrap_if_constexpr(length)
132-
assert isinstance(start, int)
133-
assert isinstance(length, int)
134-
shape = [self.shape[0], length]
137+
_check(isinstance(start, int), lambda: "start must be a constant int")
138+
_check(isinstance(length, int), lambda: "length must be a constant int")
139+
shape = self.shape[:-1] + [length]
135140
layout = self.type.layout
136141
layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked,
137142
layout.cta_split_num)
@@ -141,19 +146,13 @@ def split(self, start, length, _semantic: GluonSemantic) -> None:
141146
return ret
142147

143148
@builtin
144-
def subslice(self, index, shape=None, layout=None, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
145-
if layout is None:
146-
layout = self.type.layout
147-
if shape is None:
148-
shape = self.shape[1:]
149-
150-
index = _semantic._convert_elem_to_ir_value(index, require_i64=False)
151-
shape = [_unwrap_if_constexpr(s) for s in shape]
152-
layout = _unwrap_if_constexpr(layout)
153-
149+
def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
150+
index = _semantic.to_tensor(index)
154151
builder = _semantic.builder
155152
offsets = [builder.get_int32(0)] * self.rank
156-
offsets[0] = index
153+
offsets[0] = index.handle
154+
shape = self.shape[1:]
155+
layout = self.layout
157156
ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape)
158157
ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
159158
return ret

0 commit comments

Comments
 (0)