@@ -327,24 +327,30 @@ def tmem_relinquish_alloc_permit():
       has_side_effects=True,
   )
 
-def tmem_load(tmem_addr, shape, num, packing: int = 1):
+def _tmem_access_helper(shape, num, packing: int = 1):
   if num.bit_count() != 1 or num > 128:
     raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
   match shape:
     case "16x128b":
-      num_out_regs = 2
+      num_regs = 2
     case "16x256b":
-      num_out_regs = 4
+      num_regs = 4
     case _:
       raise NotImplementedError(f"{shape=} is unsupported")
-  if num * num_out_regs >= 256:
+  num_regs *= num
+  if num_regs > 255:
     raise ValueError(
-        f"Loading too much TMEM at once: {num=} and each load requires"
-        f" {num_out_regs} registers, which exceeds the limit of 256"
+        f"TMEM transaction too big: {shape=} and {num=} involve"
+        f" {num_regs} registers per thread, which exceeds the limit of 255"
     )
-  num_out_regs *= num
+  regs_vector = ",".join(f"${i}" for i in range(num_regs))
+  regs_vector = "{" + regs_vector + "}"
+  return num_regs, regs_vector
+
+
+def tmem_load(tmem_addr, shape, num, packing: int = 1):
   i32 = ir.IntegerType.get_signless(32)
-  out_regs = ",".join("$" + str(i) for i in range(num_out_regs))
+  num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing)
   if packing == 1:
     pack_mod = ""
   elif packing == 2:
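The register accounting in `_tmem_access_helper` is worth spelling out. Below is a minimal pure-Python sketch, runnable without MLIR; the function name `tmem_regs_per_thread` is made up for illustration, while the shape table and the 255 cap mirror the hunk above (the next operand slot is still needed for the TMEM address):

```python
def tmem_regs_per_thread(shape: str, num: int) -> int:
  if num.bit_count() != 1 or num > 128:
    raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
  base = {"16x128b": 2, "16x256b": 4}  # i32 regs per thread, per repetition
  if shape not in base:
    raise NotImplementedError(f"{shape=} is unsupported")
  num_regs = base[shape] * num
  if num_regs > 255:  # data regs are $0..$254; the address takes one more slot
    raise ValueError(f"{shape=} and {num=} need {num_regs} registers per thread")
  return num_regs

assert tmem_regs_per_thread("16x256b", 32) == 128
assert tmem_regs_per_thread("16x128b", 64) == 128
```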
@@ -356,13 +362,30 @@ def tmem_load(tmem_addr, shape, num, packing: int = 1):
356362 "!llvm.struct<(" + "," .join ("i32" for _ in range (num_out_regs )) + ")>"
357363 ),
358364 [tmem_addr ],
359- f"tcgen05.ld.sync.aligned.{ shape } .x{ num } { pack_mod } .b32 {{ { out_regs } } }, [${ num_out_regs } ];" ,
365+ f"tcgen05.ld.sync.aligned.{ shape } .x{ num } { pack_mod } .b32 { regs_vector } , [${ num_out_regs } ];" ,
360366 "=r," * num_out_regs + "r" ,
361367 has_side_effects = True ,
362368 )
363369 return [llvm .extractvalue (i32 , regs , [i ]) for i in range (num_out_regs )]
364370
365371
372+ def tmem_store (tmem_addr , shape , num , regs , packing : int = 1 ):
373+ num_out_regs , regs_vector = _tmem_access_helper (shape , num , packing )
374+ if packing == 1 :
375+ pack_mod = ""
376+ elif packing == 2 :
377+ pack_mod = ".unpack::16b"
378+ else :
379+ raise ValueError (f"Unsupported packing: { packing } " )
380+ llvm .inline_asm (
381+ ir .Type .parse ("!llvm.void" ),
382+ [* regs , tmem_addr ],
383+ f"tcgen05.st.sync.aligned.{ shape } .x{ num } { pack_mod } .b32 [${ num_out_regs } ], { regs_vector } ;" ,
384+ "r," * num_out_regs + "r" ,
385+ has_side_effects = True ,
386+ )
387+
388+
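For concreteness, here is what the strings handed to `llvm.inline_asm` expand to for `shape="16x128b"`, `num=2`, `packing=1` (plain string formatting reproducing the f-strings above; `$4` is the TMEM address operand that follows the four data registers):

```python
num_out_regs = 4  # 2 regs per "16x128b" repetition, times num=2
regs_vector = "{" + ",".join(f"${i}" for i in range(num_out_regs)) + "}"

load = f"tcgen05.ld.sync.aligned.16x128b.x2.b32 {regs_vector}, [${num_out_regs}];"
store = f"tcgen05.st.sync.aligned.16x128b.x2.b32 [${num_out_regs}], {regs_vector};"
load_constraints = "=r," * num_out_regs + "r"   # register outputs, then address
store_constraints = "r," * num_out_regs + "r"   # all inputs, address last

print(load)   # tcgen05.ld.sync.aligned.16x128b.x2.b32 {$0,$1,$2,$3}, [$4];
print(store)  # tcgen05.st.sync.aligned.16x128b.x2.b32 [$4], {$0,$1,$2,$3};
```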
 @dataclasses.dataclass(frozen=True)
 class TMEMLayout:
   """Represents the way a shape is laid out in TMEM.
@@ -562,62 +585,168 @@ def __getitem__(self, *idxs):
     )
     return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
 
+  def __setitem__(self, idxs, value):
+    if not isinstance(idxs, tuple):
+      idxs = (idxs,)
+    base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
+    if any(is_squeezed):
+      raise ValueError(
+          "TMEM stores don't support integer indexing (only slices allowed)"
+      )
+    if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
+      raise NotImplementedError("Slicing parts of TMEM not implemented yet")
+    if self.shape[1] % 8:
+      raise NotImplementedError
+    if utils.bitwidth(self.dtype) not in {16, 32}:
+      raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
+    if not isinstance(value, fa.FragmentedArray):
+      raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}")
+    if value.shape != self.shape:
+      raise ValueError(
+          f"Stored array has shape {value.shape}, but TMEM has shape"
+          f" {self.shape}"
+      )
+    if value.mlir_dtype != self.dtype:
+      raise ValueError(
+          f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype"
+          f" {self.dtype}"
+      )
+    if value.layout != LAYOUT:
+      raise ValueError(
+          f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is"
+          " supported"
+      )
+    if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
+      # _store_32xcols needs a 4xN array, but the FA tiling we use here tiles
+      # columns before rows, and so it is Nx4 (after ignoring all 1-sized dims).
+      _store_32xcols(self.address, value.registers.T.reshape((4, -1)))
+    else:  # TODO(apaszke): Collective MMA layout
+      raise NotImplementedError(
+          f"Stores only implemented for refs with standard layout, got: {self.layout}"
+      )
+
+
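The `value.registers.T.reshape((4, -1))` step is easiest to see with a toy numpy grid (the real register array carries extra singleton dims; a bare Nx4 grid with N = 3 is assumed here):

```python
import numpy as np

nx4 = np.arange(12).reshape(3, 4)  # stand-in for value.registers: Nx4
print(nx4.T.reshape(4, -1))
# [[ 0  4  8]
#  [ 1  5  9]
#  [ 2  6 10]
#  [ 3  7 11]]
# Each row now holds one 8-row strip (rows 0, 8, 16, 24 of a warp's 32-row
# band) across all N column tiles: the 4xN shape _store_32xcols expects.
```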
+def _transfer_32xcols(base_addr, cols):
+  i32 = ir.IntegerType.get_signless(32)
+  cols_per_num = 8  # Here we generate a plan compatible with tcgen05.LAYOUT.
+  assert cols % cols_per_num == 0
+  total_num = cols // cols_per_num
+  if total_num <= 32:
+    instr_num = total_num
+  elif total_num == 64:
+    instr_num = 32
+  else:
+    raise NotImplementedError(total_num)
+  # We transfer 16 lanes at a time, but have 32 to deal with.
+  for lane_step in range(2):
+    addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32))
+    cols_per_instr = instr_num * cols_per_num
+    for num_step in range(total_num // instr_num):
+      num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num)
+      addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32))
+      yield addr_row_col, instr_num, lane_step, num_slice
+
+
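Because `_transfer_32xcols` yields MLIR values, it can't be run standalone, but an integer-only replica shows the plan it produces. TMEM addresses place the lane in the upper 16 bits and the column in the lower 16, hence the `<< 16` on the row offset:

```python
def transfer_plan(cols, cols_per_num=8):
  total_num = cols // cols_per_num
  if total_num <= 32:
    instr_num = total_num
  elif total_num == 64:
    instr_num = 32  # split into two 32-wide instructions
  else:
    raise NotImplementedError(total_num)
  for lane_step in range(2):  # 16 lanes at a time, 32 in total
    addr_row = (lane_step * 16) << 16
    for num_step in range(total_num // instr_num):
      addr = addr_row + num_step * instr_num * cols_per_num
      yield hex(addr), instr_num, lane_step, (num_step * instr_num, (num_step + 1) * instr_num)

for step in transfer_plan(512):  # total_num = 64
  print(step)
# ('0x0', 32, 0, (0, 32))        lanes 0-15,  first 32 num tiles
# ('0x100', 32, 0, (32, 64))     lanes 0-15,  last 32 num tiles
# ('0x100000', 32, 1, (0, 32))   lanes 16-31, first 32 num tiles
# ('0x100100', 32, 1, (32, 64))  lanes 16-31, last 32 num tiles
```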
+def _store_32xcols(base_addr, vector_regs):
+  i32 = ir.IntegerType.get_signless(32)
+  assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4
+  cols = vector_regs.shape[1] * 8
+
+  packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
+  if packing == 1:
+    store_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
+    regs = np.empty((4, vector_regs.shape[1], 2), dtype=object)
+    c0 = arith.constant(i32, 0)
+    c1 = arith.constant(i32, 1)
+    for idx, vreg in np.ndenumerate(vector_regs):
+      regs[(*idx, 0)] = llvm.extractelement(vreg, c0)
+      regs[(*idx, 1)] = llvm.extractelement(vreg, c1)
+    regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2)
+    # From a single lane's perspective, a num tile consists of a 2x2, with the
+    # minor dim traversing columns and the major dim being 8 rows apart. See
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+    assert regs.shape[-2:] == (2, 2)
+  elif packing == 2:
+    store_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
+    # From a single lane's perspective, a num tile has 2 registers, 8 rows
+    # apart. See
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
+    regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2)
+  else:
+    raise NotImplementedError(packing)
+
+  it = _transfer_32xcols(base_addr, cols)
+  for addr_row_col, instr_num, lane_step, num_slice in it:
+    regs_slice = regs[lane_step, num_slice].flat
+    tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing)
+
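The `reshape(...).swapaxes(1, 2)` reindexing can be checked with plain numpy. In the `packing == 2` case, the four rows of `vector_regs` are 8-row strips at rows 0, 8, 16 and 24; after the shuffle, axis 0 selects the 16-lane half and the last axis holds the two strips that sit 8 rows apart within it, matching the 16x128b fragment layout (strings stand in for registers):

```python
import numpy as np

N = 2  # hypothetical number of 8-column tiles
vector_regs = np.array([[f"r{s}c{c}" for c in range(N)] for s in (0, 8, 16, 24)])
regs = vector_regs.reshape(2, 2, N).swapaxes(1, 2)
print(regs[0])  # [['r0c0' 'r8c0']    -> lanes 0-15: strips at rows 0 and 8
                #  ['r0c1' 'r8c1']]
print(regs[1])  # [['r16c0' 'r24c0']  -> lanes 16-31: strips at rows 16 and 24
                #  ['r16c1' 'r24c1']]
```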
 
 def _load_32xcols(base_addr, cols, dtype):
-  # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
   i32 = ir.IntegerType.get_signless(32)
+  vec_ty = ir.VectorType.get((2,), dtype)
   packing = 32 // utils.bitwidth(dtype)
   if packing == 1:
-    load_shape = "16x256b"  # 8 columns * 32 bits = 256 bits
-    cols_per_num_tile = 8 * packing
+    load_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
   elif packing == 2:
-    load_shape = "16x128b"  # 8 columns * 16 bits = 128 bits
-    cols_per_num_tile = 4 * packing
+    load_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
   else:
     raise NotImplementedError(packing)
-  assert cols % cols_per_num_tile == 0
-  num = cols // cols_per_num_tile
-  if num <= 32:
-    num_tiling = num
-  elif num == 64:
-    num_tiling = 32
-  else:
-    raise NotImplementedError(num)
+
   vector_regs = np.ndarray((4, cols // 8), dtype=object)
-  # We load 16 lanes at a time, but need 32 in total.
-  for row_group in range(2):
-    addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
-    regs = []
-    for num_group in range(num // num_tiling):
-      addr_row_col = arith.addi(
-          addr_row,
-          arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
-      )
-      regs += tmem_load(addr_row_col, load_shape, num_tiling, packing)
+
+  it = _transfer_32xcols(base_addr, cols)
+  c0 = arith.constant(i32, 0)
+  c1 = arith.constant(i32, 1)
+  for addr_row_col, instr_num, lane_step, num_slice in it:
+    regs = tmem_load(addr_row_col, load_shape, instr_num, packing)
+    row_slice = slice(lane_step * 2, (lane_step + 1) * 2)
+    # This aliases the original array, so updates will be reflected there.
+    vector_regs_update = vector_regs[row_slice, num_slice]
+    assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num)
     if packing == 1:
       regs = [llvm.bitcast(dtype, r) for r in regs]
-      undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
-      for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True):
-        high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
-        vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
-        vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+      # From a single lane's perspective, a num tile consists of a 2x2, with
+      # the minor dim traversing columns and the major dim being 8 rows apart.
+      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1)
+      undef = llvm.mlir_undef(vec_ty)
+      assert regs.shape == (*vector_regs_update.shape, 2)
+      for idx in np.ndindex(vector_regs_update.shape):
+        high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0)
+        vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1)
+        vector_regs_update[idx] = vreg
     else:
       assert packing == 2
-      regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs]
-      for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True):
-        vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+      regs = [llvm.bitcast(vec_ty, r) for r in regs]
+      # From a single lane's perspective, a num tile has 2 registers, 8 rows
+      # apart. See
+      # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
+      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1)
+      vector_regs_update[...] = regs
+
   return vector_regs
 
 
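The `packing` arithmetic decides whether one 32-bit result register already contains a whole `vector<2 x dtype>` (16-bit types, a single `bitcast`) or whether two registers must be glued together with `insertelement` (32-bit types). A byte-level illustration of the 16-bit case, assuming the usual little-endian register layout:

```python
import numpy as np

lo, hi = np.float16(1.5), np.float16(-2.0)
reg = int.from_bytes(lo.tobytes() + hi.tobytes(), "little")  # one i32 register
print(hex(reg))  # 0xc0003e00: -2.0 in the high half, 1.5 in the low half
print(np.frombuffer(reg.to_bytes(4, "little"), dtype=np.float16))  # [ 1.5 -2. ]
```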
-# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
 def _m128_layout(shape: tuple[int, ...]):
   if len(shape) != 2:
     raise ValueError(f"Shape {shape} is not 2D")
   if shape[0] % 128 != 0 or shape[1] % 8 != 0:
     raise ValueError(f"Shape {shape} is not a multiple of 128x8")
-  return fa.TiledLayout(
-      fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
-      warp_dim=-8,
-      lane_dims=(-4, -3),
-      vector_dim=-1,
+  return LAYOUT
+
+
+# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
+# The name is so short because it's meant to be used qualified (tcgen05.LAYOUT).
+LAYOUT = fa.TiledLayout(
+    fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
+    warp_dim=-8,
+    lane_dims=(-4, -3),
+    vector_dim=-1,
+)
+
+
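A quick footprint check for LAYOUT (plain arithmetic, not the `TiledLayout` API): the 128 rows of a tile split across 4 warps of 32 lanes, so each of the 128 threads ends up with rows * cols / 128 elements, grouped into 2-element vector registers by `vector_dim=-1`:

```python
def vregs_per_thread(rows, cols):
  assert rows % 128 == 0 and cols % 8 == 0  # same precondition as _m128_layout
  return rows * cols // 128 // 2

print(vregs_per_thread(128, 64))   # 32 vregs, i.e. 64 elements per thread
print(vregs_per_thread(256, 128))  # 128 vregs, i.e. 256 elements per thread
```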
+def commit_tmem():
+  void = ir.Type.parse("!llvm.void")
+  llvm.inline_asm(
+      void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True,
   )
+  utils.warpgroup_barrier()
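Stores issued by `tmem_store` are asynchronous, so a reader must not touch the data until they are committed. A hedged usage sketch (not runnable outside a Mosaic GPU lowering context; `tmem_ref` and `acc` are hypothetical):

```python
def publish(tmem_ref, acc):
  tmem_ref[...] = acc  # __setitem__ above issues the tcgen05.st instructions
  # wait::st only covers this thread's own stores; the warpgroup barrier in
  # commit_tmem makes every lane's stores visible to the whole warpgroup.
  commit_tmem()
```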