@@ -191,6 +191,9 @@ def mangle(self) -> str:
 
 
 class shared_memory_descriptor(base_value):
+    """
+    Represents a handle to a shared memory allocation in Gluon IR.
+    """
 
     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
         self.handle = handle
@@ -220,39 +223,104 @@ def __str__(self) -> str:
 
     @builtin
     def load(self, layout, _semantic: GluonSemantic) -> tensor:
+        """
+        Load a tensor from shared memory.
+
+        Args:
+            layout (DistributedLayout): The destination layout of the tensor.
+
+        Returns:
+            tensor: A Gluon tensor containing the loaded data.
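+
+        Example (illustrative sketch; assumes ``smem`` is this descriptor and
+        ``blocked_layout`` is a DistributedLayout defined earlier in the kernel):
+
+            x = smem.load(blocked_layout)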
+        """
         layout = _unwrap_if_constexpr(layout)
         return _semantic.shared_load(self, layout)
 
     @builtin
     def store(self, value, _semantic: GluonSemantic) -> None:
+        """
+        Store a tensor into shared memory.
+
+        Args:
+            value (tensor): The tensor to write into shared memory.
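+
+        Example (illustrative sketch; assumes ``smem`` was allocated with a
+        dtype and shape matching the registers tensor ``x``):
+
+            smem.store(x)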
+        """
         return _semantic.shared_store(self, value)
 
     @builtin
     def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by slicing along a given dimension.
+
+        Args:
+            start (int): The starting index of the slice.
+            length (int): The length of the slice.
+            dim (int): The dimension to slice (default: 0).
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the sliced subview.
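+
+        Example (illustrative sketch; takes the first 16 rows of a 2D
+        allocation ``smem``):
+
+            sub = smem.slice(0, 16, dim=0)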
+        """
         start = _unwrap_if_constexpr(start)
         length = _unwrap_if_constexpr(length)
         dim = _unwrap_if_constexpr(dim)
         return _semantic.memdesc_slice(self, start, length, dim)
 
     @builtin
     def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by indexing along the first dimension.
+
+        Args:
+            index (int): The index at which to take the subview.
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the indexed subview.
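+
+        Example (illustrative sketch; selects buffer ``i`` from a
+        multi-buffered allocation ``smem``):
+
+            buf = smem.index(i)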
+        """
         index = _unwrap_if_constexpr(index)
         return _semantic.memdesc_index(self, index)
 
     @builtin
     def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
+        """
+        Permute the dimensions of the shared memory descriptor.
+
+        Args:
+            order (List[int]): The new ordering of dimensions.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with permuted dimensions.
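+
+        Example (illustrative sketch; swaps the two dimensions of a 2D
+        descriptor ``smem``):
+
+            transposed = smem.permute([1, 0])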
+        """
         order = [_unwrap_if_constexpr(o) for o in order]
         return _semantic.memdesc_trans(self, order)
 
     @builtin
     def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
+        """
+        Reshape the shared memory descriptor to a new shape and layout.
+
+        Args:
+            shape (List[int]): The target shape.
+            layout (SharedLayout): The new layout for the descriptor.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with the new shape and layout.
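+
+        Example (illustrative sketch; assumes ``smem`` has shape [M, N] and
+        ``flat_layout`` is a SharedLayout compatible with the flattened shape):
+
+            flat = smem.reshape([M * N], flat_layout)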
+        """
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
 
         return _semantic.memdesc_reshape(self, shape, layout)
 
     @builtin
     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reinterpret the shared memory descriptor as a different dtype, shape, or layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (List[int]): The new shape.
+            layout (SharedLayout): The new layout.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with updated type and layout.
+        """
         dtype = _unwrap_if_constexpr(dtype)
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
@@ -261,6 +329,9 @@ def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) ->
 
     @builtin
     def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
+        """
+        Dummy use to keep the shared memory descriptor alive.
+        """
         return _semantic.shared_dealloc(self)
 
 
@@ -271,6 +342,17 @@ def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
 
 @builtin
 def arange(start, end, layout, _semantic=None):
+    """
+    Generate a sequence tensor with values in [start, end) using a specified layout.
+
+    Args:
+        start (int): Inclusive start of the sequence.
+        end (int): Exclusive end of the sequence.
+        layout (DistributedLayout): The layout of the output tensor.
+
+    Returns:
+        tensor: A 1D tensor containing sequential values.
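+
+    Example (illustrative sketch; assumes the Gluon language module is
+    imported as ``ttgl`` and ``layout_1d`` is a suitable 1D DistributedLayout):
+
+        offsets = ttgl.arange(0, 128, layout_1d)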
+    """
     start = _unwrap_if_constexpr(start)
     end = _unwrap_if_constexpr(end)
     layout = _unwrap_if_constexpr(layout)
@@ -279,12 +361,34 @@ def arange(start, end, layout, _semantic=None):
 
 @builtin
 def convert_layout(value, layout, _semantic=None):
+    """
+    Convert a tensor to a different distributed layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+
+    Returns:
+        tensor: The tensor with the new layout.
282374 layout = _unwrap_if_constexpr (layout )
283375 return _semantic .convert_layout (value , layout )
284376
285377
 @builtin
 def full(shape, value, dtype, layout, _semantic=None):
+    """
+    Create a tensor filled with a scalar value, with specified shape, dtype, and layout.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        value (int or float): The fill value.
+        dtype (dtype): The data type for the tensor.
+        layout (DistributedLayout): The layout of the output tensor.
+
+    Returns:
+        tensor: A tensor where every element equals value.
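+
+    Example (illustrative sketch; assumes the module alias ``ttgl`` and a
+    DistributedLayout ``blocked_layout`` defined earlier in the kernel):
+
+        zeros = ttgl.full([128, 128], 0.0, ttgl.float32, blocked_layout)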
+    """
     shape = _unwrap_shape(shape)
     value = _unwrap_if_constexpr(value)
     dtype = _unwrap_if_constexpr(dtype)
@@ -294,6 +398,18 @@ def full(shape, value, dtype, layout, _semantic=None):
 
 @builtin
 def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
+    """
+    Allocate shared memory for a tensor with the given element type, shape, and layout.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The dimensions of the shared memory.
+        layout (SharedLayout): The shared memory layout.
+        value (tensor, optional): Initial value to copy into shared memory.
+
+    Returns:
+        shared_memory_descriptor: Descriptor for the allocated memory.
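+
+    Example (illustrative sketch; assumes the module alias ``ttgl`` and a
+    SharedLayout ``shared_layout`` defined earlier in the kernel):
+
+        smem = ttgl.allocate_shared_memory(ttgl.float16, [64, 64], shared_layout)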
+    """
     element_ty = _unwrap_if_constexpr(element_ty)
     shape = _unwrap_if_constexpr(shape)
     shape = [_unwrap_if_constexpr(s) for s in shape]
@@ -304,6 +420,20 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
 @builtin
 def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
+    """
+    Create a warp-specialized execution region, partitioning work across warps.
+
+    Args:
+        default_args (List[Any]): Arguments for the default region.
+        default_partition (callable): Function to build the default execution region.
+        worker_args (List[Any]): Arguments for each warp partition.
+        worker_partitions (List[callable]): Functions for each warp partition.
+        worker_num_warps (List[int]): Number of warps per partition.
+        worker_num_regs (List[int]): Number of registers per partition.
+
+    Returns:
+        Tuple[Any, ...]: Results from the default region.
307437 worker_num_warps = [_unwrap_if_constexpr (w ) for w in worker_num_warps ]
308438 worker_num_regs = [_unwrap_if_constexpr (r ) for r in worker_num_regs ]
309439 return _semantic .warp_specialize (default_args , default_partition , worker_args , worker_partitions , worker_num_warps ,
@@ -312,4 +442,7 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
 
 @builtin
 def thread_barrier(_semantic=None):
+    """
+    Insert a barrier to synchronize threads within a CTA.
+    """
     return _semantic.debug_barrier()