@@ -252,6 +252,24 @@ def named_region(self, *args, **kwargs):
     else:
       yield
 
+  def cluster_idx(
+      self, dim: gpu.Dimension | Sequence[gpu.Dimension] | None = None
+  ) -> ir.Value:
+    """Returns the index of a block within a subset of the cluster spanned by the given dimensions."""
+    if dim is None:
+      dim = gpu.Dimension
+    elif isinstance(dim, gpu.Dimension):
+      dim = (dim,)
+    index = ir.IndexType.get()
+    stride = 1
+    idx = c(0, index)
+    for d in sorted(dim):
+      if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
+        continue
+      idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
+      stride *= self.cluster_size[d]
+    return idx
+
   def _alloc_scratch(
       self,
       size: int,
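For intuition, the following plain-Python sketch mirrors the linearization `cluster_idx` performs above. It is an illustration only: the `block_id` and `cluster_size` dictionaries are hypothetical stand-ins for `gpu.cluster_block_id` and `self.cluster_size`, and the real method emits MLIR `arith` ops rather than computing Python integers.

```python
# Pure-Python model of the linearization done by cluster_idx (sketch only).
def cluster_idx_model(block_id: dict, cluster_size: dict, dims) -> int:
  idx, stride = 0, 1
  for d in sorted(dims):
    if cluster_size[d] == 1:  # Dimensions of size 1 contribute nothing.
      continue
    idx += block_id[d] * stride
    stride *= cluster_size[d]
  return idx

# Example: in a (2, 2, 1) cluster, the block at (x=1, y=1, z=0) gets index 3
# when linearized over the x and y dimensions.
assert cluster_idx_model(
    {"x": 1, "y": 1, "z": 0}, {"x": 2, "y": 2, "z": 1}, ("x", "y")
) == 3
```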
@@ -355,8 +373,35 @@ def async_copy(
       arrive: bool | None = None,
       uniform: bool = True,
       collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
+      partitioned: int | None = None,
       predicate: ir.Value | None = None,  # Should select 0 or 1 threads from the WG.
   ):
+    """Initiates an async copy between GMEM and SMEM.
+
+    Exactly one of `src_ref` and `dst_ref` must be in GMEM and the other in
+    SMEM, and the SMEM reference must be contiguous. The GMEM window that is
+    read or written is specified by `gmem_slice`. The copy can change the order
+    in which the data appears in the window by applying a sequence of
+    transforms to the GMEM reference (as specified by `gmem_transform`).
+
+    When `collective` is specified (only allowed for GMEM -> SMEM copies), an
+    identical async_copy must be scheduled by all blocks that share the same
+    coordinates along the collective dimensions within a cluster; otherwise the
+    behavior is undefined. The semantics of collective loads further depend on
+    the `partitioned` argument:
+
+    - If `partitioned` is not specified, all blocks load the same data into
+      their shared memory and all receive the update in their barriers, unless
+      `arrive` is False. If `arrive` is False, you should expect the barrier to
+      have expect_tx incremented by the same number of bytes as if `collective`
+      was not specified.
+    - If `partitioned` is specified, each block only loads a separate slice of
+      the data into SMEM, partitioned into equal tiles along the `partitioned`
+      dimension. In this case only the barrier of the first block in the
+      collective will have its expect_tx incremented by the total size of the
+      transfer across all blocks involved in the collective. Barriers supplied
+      by the other blocks will be ignored (even if `arrive` is True).
+    """
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
     i32 = ir.IntegerType.get_signless(32)
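The two docstring bullets above can be condensed into a small accounting model, shown below as a plain-Python sketch under the stated semantics (not the real implementation): it returns how many bytes a given block's barrier is told to expect, depending on `arrive`, `partitioned`, and the block's index within the collective.

```python
# Sketch of the expect_tx accounting described in the docstring above.
# `total_bytes` is the size of the full GMEM window; the names are illustrative.
def expected_tx(total_bytes: int, block_idx: int, *,
                partitioned: bool, arrive: bool) -> int:
  if not arrive:
    # async_copy does not arrive; the caller must account for `total_bytes`
    # itself (the same amount as for a non-collective copy).
    return 0
  if not partitioned:
    # Every block loads the full window and its own barrier tracks all of it.
    return total_bytes
  # Partitioned: each block loads only a tile, but the whole transfer is
  # tracked by the first block's barrier; the other barriers are ignored.
  return total_bytes if block_idx == 0 else 0
```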
@@ -408,13 +453,46 @@ def async_copy(
           " multiple of 16 bytes"
       )
 
-    # TMA supports OOB indices, so we skip the check.
+    # NOTE: TMA supports OOB indices, so we skip the check.
     base_indices, slice_shape, is_squeezed = utils.parse_indices(
         gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False
     )
     dyn_base_indices = tuple(
         c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
     )
+    del base_indices  # Use the dynamic indices from now on!
+
+    collective_size = 1
+    if collective is not None:
+      if isinstance(collective, gpu.Dimension):
+        collective = (collective,)
+      collective_size = math.prod(self.cluster_size[d] for d in collective)
+      if gmem_ref is dst_ref:
+        raise ValueError("Only GMEM -> SMEM copies can be collective")
+    if partitioned is not None:
+      if collective is None:
+        raise ValueError("Only collective loads can be partitioned")
+    if collective_size > 1 and partitioned is not None:
+      if math.prod(self.cluster_size) != 2:
+        raise NotImplementedError(
+            "Partitioned loads only supported for clusters of size 2"
+        )
+      if slice_shape[partitioned] % collective_size != 0:
+        raise ValueError(
+            f"The collective size ({collective_size}) must divide the slice"
+            " shape along the partitioned dimension, but it has size"
+            f" {slice_shape[partitioned]}"
+        )
+      slice_shape[partitioned] //= collective_size
+      dyn_base_indices = list(dyn_base_indices)
+      dyn_base_indices[partitioned] = arith.addi(
+          dyn_base_indices[partitioned],
+          arith.muli(
+              self.cluster_idx(collective), c(slice_shape[partitioned], index)
+          ),
+      )
+      dyn_base_indices = tuple(dyn_base_indices)
+
     squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed]
     sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed]
     # Indexing is really slicing + squeezing, and user transforms are meant to
@@ -472,12 +550,9 @@ def async_copy(
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
     assert all(d == 1 for d in slice_shape[:num_squeezed_dims])
-    collective_size = 1
-    if collective is not None:
-      if isinstance(collective, gpu.Dimension):
-        collective = (collective,)
-      collective_size = math.prod(self.cluster_size[d] for d in collective)
-    if collective_size > 1:
+
+    # Partitioned loads have already been processed (before transforms).
+    if collective_size > 1 and partitioned is None:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         # No need to partition squeezed dims. They don't even exist in smem_ref.
         assert dim >= num_squeezed_dims
@@ -490,13 +565,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
             (slice(None),) * (dim - num_squeezed_dims)
             + (utils.ds(block_offset, slice_shape[dim]),),
         )
-      stride = 1
-      idx = c(0, index)
-      for d in sorted(collective):
-        if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
-          continue
-        idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
-        stride *= self.cluster_size[d]
+      idx = self.cluster_idx(collective)
       rem_collective_size = collective_size
       for dim, slice_size in enumerate(slice_shape[:-1]):
         if slice_size % rem_collective_size == 0:
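For the non-partitioned collective path, `partition_dim` above shrinks one dimension of the slice and advances this block's base index into it. A plain-Python model of just that step (a sketch that ignores the MLIR values and the accompanying SMEM slicing) looks like this:

```python
# Minimal model of partition_dim: split `dim` into `num_chunks` equal chunks
# and advance this block's base index by its chunk offset.
def partition_dim_model(slice_shape: list, base_indices: list,
                        dim: int, idx: int, num_chunks: int) -> None:
  slice_shape[dim] //= num_chunks
  base_indices[dim] += idx * slice_shape[dim]

# E.g. 4 collective blocks splitting a (128, 64) window along dimension 0:
shape, base = [128, 64], [0, 0]
partition_dim_model(shape, base, dim=0, idx=3, num_chunks=4)
assert shape == [32, 64] and base == [96, 0]
```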
@@ -572,15 +641,44 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
       )
       barrier_ptr = barrier.get_ptr()
       with uniform_ctx():
-        if arrive:
-          nvvm.mbarrier_arrive_expect_tx_shared(
-              barrier_ptr, transfer_bytes, predicate=predicate
+        if collective_size > 1 and partitioned is not None:
+          if predicate is None:
+            predicate = c(1, ir.IntegerType.get_signless(1))
+          if arrive:
+            first_block = arith.cmpi(
+                arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index),
+            )
+            arrive_predicate = arith.andi(predicate, first_block)
+            nvvm.mbarrier_arrive_expect_tx_shared(
+                barrier_ptr, transfer_bytes, predicate=arrive_predicate
+            )
+          rank = len(slice_shape)
+          idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank))
+          llvm.inline_asm(
+              ir.Type.parse("!llvm.void"),
+              [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices],
+              f"""
+              {{
+              .reg .b32 mapped_addr;
+              @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0;
+              @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2
+                  [$1], [$2, {{{idx_operands}}}], [mapped_addr];
+              }}
+              """,
+              "b,r,l,r" + ",r" * rank,
+              has_side_effects=True,
+          )
+        else:
+          if arrive:
+            nvvm.mbarrier_arrive_expect_tx_shared(
+                barrier_ptr, transfer_bytes, predicate=predicate
+            )
+          nvvm.cp_async_bulk_tensor_shared_cluster_global(
+              smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
+              multicast_mask=multicast_mask, predicate=predicate
           )
-        nvvm.cp_async_bulk_tensor_shared_cluster_global(
-            smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
-            multicast_mask=multicast_mask, predicate=predicate
-        )
     else:
+      assert multicast_mask is None
       with uniform_ctx():
         nvvm.cp_async_bulk_tensor_global_shared_cta(
             tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
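To make the inline PTX in the partitioned branch easier to read, the snippet below reproduces the same string construction for `rank == 2` in plain Python (illustration only) and spells out the operand mapping implied by the `"b,r,l,r" + ",r" * rank` constraint string: $0 is the i1 predicate, $1 the SMEM destination pointer, $2 the TMA descriptor, $3 the local mbarrier pointer, and $4 onward the reversed base indices. The `mapa` instruction rewrites the barrier address into the corresponding address in block 0 of the cluster, so the copy completion is tracked on that block's barrier.

```python
# Standalone reconstruction of the asm string for rank == 2 (illustration only).
rank = 2
idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank))
constraints = "b,r,l,r" + ",r" * rank  # -> "b,r,l,r,r,r"
ptx = f"""
{{
  .reg .b32 mapped_addr;
  @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0;
  @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2
      [$1], [$2, {{{idx_operands}}}], [mapped_addr];
}}
"""
print(constraints)
print(ptx)
```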