
Commit cf80141

[GLUON] Docstrings for public functions (#7323)
Adding documentation for the public Gluon API. Generated by codex, with proof-reading and editing of some of the issues by me. Formatting differs from Triton docstrings (Google style rather than reST), but subjectively this seems more readable. Changing to reST formatting should be relatively easy if people have strong opinions about this.

---------

Co-authored-by: peterbell10 <[email protected]>
1 parent 3043f5e commit cf80141

9 files changed: 469 additions and 3 deletions

python/triton/experimental/gluon/language/_core.py

Lines changed: 133 additions & 0 deletions
@@ -191,6 +191,9 @@ def mangle(self) -> str:
 
 
 class shared_memory_descriptor(base_value):
+    """
+    Represents a handle to a shared memory allocation in Gluon IR.
+    """
 
     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
         self.handle = handle
@@ -220,39 +223,104 @@ def __str__(self) -> str:
 
     @builtin
     def load(self, layout, _semantic: GluonSemantic) -> tensor:
+        """
+        Load a tensor from shared memory.
+
+        Args:
+            layout (DistributedLayout): The destination layout of the tensor.
+
+        Returns:
+            tensor: A Gluon tensor containing the loaded data.
+        """
         layout = _unwrap_if_constexpr(layout)
         return _semantic.shared_load(self, layout)
 
     @builtin
     def store(self, value, _semantic: GluonSemantic) -> None:
+        """
+        Store a tensor into shared memory.
+
+        Args:
+            value (tensor): The tensor whose contents to store.
+        """
         return _semantic.shared_store(self, value)
 
     @builtin
     def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by slicing along a given dimension.
+
+        Args:
+            start (int): The starting index of the slice.
+            length (int): The length of the slice.
+            dim (int): The dimension to slice (default: 0).
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the sliced subview.
+        """
         start = _unwrap_if_constexpr(start)
         length = _unwrap_if_constexpr(length)
         dim = _unwrap_if_constexpr(dim)
         return _semantic.memdesc_slice(self, start, length, dim)
 
     @builtin
     def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by indexing along the first dimension.
+
+        Args:
+            index (int): The index at which to take the subview.
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the indexed subview.
+        """
         index = _unwrap_if_constexpr(index)
         return _semantic.memdesc_index(self, index)
 
     @builtin
     def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
+        """
+        Permute the dimensions of the shared memory descriptor.
+
+        Args:
+            order (List[int]): The new ordering of dimensions.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with permuted dimensions.
+        """
         order = [_unwrap_if_constexpr(o) for o in order]
         return _semantic.memdesc_trans(self, order)
 
     @builtin
     def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
+        """
+        Reshape the shared memory descriptor to a new shape and layout.
+
+        Args:
+            shape (List[int]): The target shape.
+            layout (SharedLayout): The new layout for the descriptor.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with the new shape and layout.
+        """
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
 
         return _semantic.memdesc_reshape(self, shape, layout)
 
     @builtin
     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reinterpret the shared memory descriptor as a different dtype, shape, or layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (List[int]): The new shape.
+            layout (SharedLayout): The new layout.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with updated type and layout.
+        """
         dtype = _unwrap_if_constexpr(dtype)
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
@@ -261,6 +329,9 @@ def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) ->
 
     @builtin
     def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
+        """
+        Dummy use to keep the shared memory descriptor alive.
+        """
         return _semantic.shared_dealloc(self)
 
 
@@ -271,6 +342,17 @@ def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
 
 @builtin
 def arange(start, end, layout, _semantic=None):
+    """
+    Generate a sequence tensor with values in [start, end) using a specified layout.
+
+    Args:
+        start (int): Inclusive start of the sequence.
+        end (int): Exclusive end of the sequence.
+        layout (DistributedLayout): The layout of the output tensor.
+
+    Returns:
+        tensor: A 1D tensor containing sequential values.
+    """
     start = _unwrap_if_constexpr(start)
     end = _unwrap_if_constexpr(end)
     layout = _unwrap_if_constexpr(layout)
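To make the documented arange signature concrete, here is a minimal usage sketch. The gluon.jit entry point, the ttgl import alias, the constexpr annotation, and the particular layout values are assumptions made for this example; only the arange and BlockedLayout signatures come from this commit.

from triton.experimental import gluon                   # assumed entry point for @gluon.jit
from triton.experimental.gluon import language as ttgl  # assumed public alias

@gluon.jit
def arange_example():
    # A 1D blocked layout covering exactly 128 elements:
    # 1 element per thread x 32 threads per warp x 4 warps per CTA.
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                                warps_per_cta=[4], order=[0])
    offsets = ttgl.arange(0, 128, layout)  # values 0..127, distributed according to layout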
@@ -279,12 +361,34 @@ def arange(start, end, layout, _semantic=None):
 
 @builtin
 def convert_layout(value, layout, _semantic=None):
+    """
+    Convert a tensor to a different distributed layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
     layout = _unwrap_if_constexpr(layout)
     return _semantic.convert_layout(value, layout)
 
 
 @builtin
 def full(shape, value, dtype, layout, _semantic=None):
+    """
+    Create a tensor filled with a scalar value, with specified shape, dtype, and layout.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        value (int or float): The fill value.
+        dtype (dtype): The data type for the tensor.
+        layout (DistributedLayout): The layout of the output tensor.
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
     shape = _unwrap_shape(shape)
     value = _unwrap_if_constexpr(value)
     dtype = _unwrap_if_constexpr(dtype)
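Similarly, a sketch of full followed by convert_layout. The kernel scaffolding, the layout parameters, and the ttgl.float32 dtype (assumed to mirror triton.language.float32) are illustrative assumptions; only the two documented signatures come from the diff.

from triton.experimental import gluon                   # assumed entry point for @gluon.jit
from triton.experimental.gluon import language as ttgl  # assumed public alias

@gluon.jit
def fill_and_convert_example():
    src: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                             warps_per_cta=[4], order=[0])
    dst: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[32],
                                             warps_per_cta=[2], order=[0])
    x = ttgl.full([128], 1.0, ttgl.float32, src)  # every element equals 1.0
    y = ttgl.convert_layout(x, dst)               # same values, redistributed across threads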
@@ -294,6 +398,18 @@ def full(shape, value, dtype, layout, _semantic=None):
 
 @builtin
 def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
+    """
+    Allocate shared memory for a tensor with the given element type, shape, and layout.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The dimensions of the shared memory.
+        layout (SharedLayout): The shared memory layout.
+        value (tensor, optional): Initial value to copy into shared memory.
+
+    Returns:
+        shared_memory_descriptor: Descriptor for the allocated memory.
+    """
     element_ty = _unwrap_if_constexpr(element_ty)
     shape = _unwrap_if_constexpr(shape)
     shape = [_unwrap_if_constexpr(s) for s in shape]
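The allocation documented here pairs with the shared_memory_descriptor methods above (store, load, slice). A minimal round-trip sketch; the imports, the ttgl.float32 dtype, and the concrete layouts are assumptions rather than part of the diff.

from triton.experimental import gluon                   # assumed entry point for @gluon.jit
from triton.experimental.gluon import language as ttgl  # assumed public alias

@gluon.jit
def shared_memory_example():
    reg_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8],
                                                    warps_per_cta=[4, 1], order=[1, 0])
    smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])
    smem = ttgl.allocate_shared_memory(ttgl.float32, [128, 64], smem_layout)
    tile = ttgl.full([128, 64], 0.0, ttgl.float32, reg_layout)
    smem.store(tile)               # registers -> shared memory
    top = smem.slice(0, 64)        # descriptor for the first 64 rows (dim 0)
    loaded = smem.load(reg_layout) # shared memory -> registers, in reg_layout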
@@ -304,6 +420,20 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
 @builtin
 def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
+    """
+    Create a warp-specialized execution region, partitioning work across warps.
+
+    Args:
+        default_args (List[Any]): Arguments for the default region.
+        default_partition (callable): Function to build the default execution region.
+        worker_args (List[Any]): Arguments for each warp partition.
+        worker_partitions (List[callable]): Functions for each warp partition.
+        worker_num_warps (List[int]): Number of warps per partition.
+        worker_num_regs (List[int]): Number of registers per partition.
+
+    Returns:
+        Tuple[Any, ...]: Results from the default region.
+    """
    worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
    worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
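For warp_specialize, the following shows only the call shape implied by the documented signature; the partition functions, argument lists, and warp/register counts are hypothetical placeholders, not taken from the commit.

from triton.experimental import gluon                   # assumed entry point for @gluon.jit
from triton.experimental.gluon import language as ttgl  # assumed public alias

@gluon.jit
def default_partition(x):
    pass  # work done by the default warps

@gluon.jit
def worker0(x):
    pass  # work done by the first worker partition

@gluon.jit
def worker1(x):
    pass  # work done by the second worker partition

@gluon.jit
def warp_specialize_example(x):
    # Arguments follow the documented order: default_args, default_partition,
    # worker_args, worker_partitions, worker_num_warps, worker_num_regs.
    ttgl.warp_specialize([x], default_partition, [x], [worker0, worker1],
                         [4, 4],    # warps per worker partition
                         [80, 80])  # registers per worker partition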
@@ -312,4 +442,7 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
 
 @builtin
 def thread_barrier(_semantic=None):
+    """
+    Insert a barrier to synchronize threads within a CTA.
+    """
     return _semantic.debug_barrier()

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 72 additions & 0 deletions
@@ -22,11 +22,26 @@ def _realize_cta_layout(layout, rank):
 
 
 class DistributedLayout:
+    """
+    Base class for distributed memory layouts in Gluon IR.
+    """
     pass
 
 
 @dataclass(frozen=True)
 class BlockedLayout(DistributedLayout):
+    """
+    Represents a blocked layout, partitioning a tensor across threads, warps, and CTAs.
+
+    Args:
+        size_per_thread (List[int]): Number of elements per thread per dimension.
+        threads_per_warp (List[int]): Number of threads per warp per dimension.
+        warps_per_cta (List[int]): Number of warps per CTA per dimension.
+        order (List[int]): The ordering of dimensions for partitioning.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): Ordering for CTAs.
+    """
     size_per_thread: List[int]
     threads_per_warp: List[int]
     warps_per_cta: List[int]
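For reference, a BlockedLayout is constructed by filling in the fields listed above. The numbers below are an arbitrary 2D example; the ttgl alias and the defaulting of the optional CTA fields are assumptions.

from triton.experimental.gluon import language as ttgl  # assumed public alias

# Each thread owns a 1x8 slice, a warp is 4x8 threads, the CTA holds 4x1 warps,
# and dimension 1 is the fastest-varying.
blocked = ttgl.BlockedLayout(
    size_per_thread=[1, 8],
    threads_per_warp=[4, 8],
    warps_per_cta=[4, 1],
    order=[1, 0],
)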
@@ -83,6 +98,13 @@ def stringify(x):
 
 @dataclass(frozen=True)
 class SliceLayout(DistributedLayout):
+    """
+    Represents a layout corresponding to slicing a distributed tensor along one dimension.
+
+    Args:
+        dim (int): The dimension index to slice.
+        parent (DistributedLayout): The parent layout before slicing.
+    """
     dim: int
     parent: DistributedLayout
 
@@ -102,6 +124,17 @@ def mangle(self) -> str:
 
 @dataclass(frozen=True)
 class DistributedLinearLayout(DistributedLayout):
+    """
+    Represents a linear distributed layout with explicit bases at register, lane, warp, and block levels.
+    See: https://arxiv.org/abs/2505.23819 for reference.
+
+    Args:
+        reg_bases (List[List[int]]): Bases for register-level distribution.
+        lane_bases (List[List[int]]): Bases for lane-level distribution.
+        warp_bases (List[List[int]]): Bases for warp-level distribution.
+        block_bases (List[List[int]]): Bases for block-level distribution.
+        shape (List[int]): The tensor global shape.
+    """
     reg_bases: List[List[int]]
     lane_bases: List[List[int]]
     warp_bases: List[List[int]]
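The bases are easiest to read off a small hypothetical example (not from this commit): a 1D tensor of 256 elements held by 4 warps of 32 lanes with 2 registers per thread, where each basis gives the element offset contributed by one bit of the register, lane, or warp index.

from triton.experimental.gluon import language as ttgl  # assumed public alias

linear = ttgl.DistributedLinearLayout(
    reg_bases=[[1]],                         # 1 register bit -> 2 elements per thread
    lane_bases=[[2], [4], [8], [16], [32]],  # 5 lane bits -> 32 lanes
    warp_bases=[[64], [128]],                # 2 warp bits -> 4 warps
    block_bases=[],                          # single CTA
    shape=[256],
)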
@@ -136,6 +169,17 @@ def mangle(self):
 
 @dataclass(frozen=True)
 class NVMMADistributedLayout(DistributedLayout):
+    """
+    Represents a layout for NVIDIA MMA (tensor core) operations.
+
+    Args:
+        version (List[int]): Version identifier for the MMA instruction.
+        warps_per_cta (List[int]): Number of warps per CTA.
+        instr_shape (List[int]): Instruction shape for MMA.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
     version: List[int]
     warps_per_cta: List[int]
     instr_shape: List[int]
@@ -166,11 +210,27 @@ def mangle(self) -> str:
 
 
 class SharedLayout:
+    """
+    Base class for shared memory layouts in Gluon IR.
+    """
    pass
 
 
 @dataclass(frozen=True)
 class NVMMASharedLayout(SharedLayout):
+    """
+    Represents a layout for shared memory suitable for NVIDIA MMA operations.
+
+    Args:
+        swizzle_byte_width (int): Width in bytes for swizzling.
+        element_bitwidth (int): Bitwidth of element type.
+        rank (int): Rank of the tensor.
+        transposed (bool): Whether the layout is transposed.
+        fp4_padded (bool): Whether FP4 padding is used.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
     swizzle_byte_width: int
     element_bitwidth: int
     rank: int
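An illustrative construction with values that are plausible for a 128-byte-swizzled 16-bit operand tile; none of these numbers come from the diff, and the optional CTA fields are assumed to take their defaults.

from triton.experimental.gluon import language as ttgl  # assumed public alias

nvmma_smem = ttgl.NVMMASharedLayout(
    swizzle_byte_width=128,  # 128-byte swizzle pattern
    element_bitwidth=16,     # fp16/bf16 elements
    rank=2,                  # 2D operand tile
    transposed=False,
    fp4_padded=False,
)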
@@ -215,6 +275,18 @@ def mangle(self) -> str:
 
 @dataclass(frozen=True, eq=True)
 class SwizzledSharedLayout(SharedLayout):
+    """
+    Represents a generic swizzled shared memory layout.
+
+    Args:
+        vec (int): Vector width for swizzling.
+        per_phase (int): Elements per swizzle phase.
+        max_phase (int): Maximum number of swizzle phases.
+        order (List[int]): Dimension ordering for swizzling.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
     vec: int
     per_phase: int
     max_phase: int
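And the generic swizzled variant, again with arbitrary example values; vec=1, per_phase=1, max_phase=1 amounts to no swizzling at all.

from triton.experimental.gluon import language as ttgl  # assumed public alias

# No swizzling: a plain shared layout with dimension 1 contiguous.
plain = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])

# Spread 8-element vectors across 8 phases to reduce shared-memory bank conflicts.
swizzled = ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0])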
