@@ -526,6 +526,18 @@ def linear_thread_idxs(self):
     lane_dims=(-4, -2, -3),
     vector_dim=-1,
 )
+# This layout should be used when upcasting 4-bit elements to 16-bit, for the
+# purpose of passing them into WGMMA later. The core matrices stored by a warp
+# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
+# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
+# group of 4 threads in order (as opposed to the swapping between 1 and 2,
+# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
+WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
+    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
+    warp_dim=-7,
+    lane_dims=(-3, -2),
+    vector_dim=-1,
+)
 # This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
 # submatrix in the following way (we only show the first 4 rows for brevity):
 #
@@ -739,58 +751,132 @@ def to_layout(self, new_layout: FragmentedLayout):
           _layout=new_layout,
           _is_signed=self.is_signed,
       )
-    if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0:
-      if (
-          self.layout == WGMMA_LAYOUT_UPCAST_2X
-          and new_layout == WGMMA_LAYOUT
-          and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16}
-      ):
-        assert shape[1] % 16 == 0  # Should be implied by the layout
-        new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
-        is_even = arith.cmpi(
-            arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_2X
+        and new_layout == WGMMA_LAYOUT
+        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
+    ):
+      assert shape[1] % 16 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      is_even = arith.cmpi(
+          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
+      )
+      registers = self.registers
+      if dtype_bitwidth == 4:
+        if registers.shape[1] % 2:
+          raise NotImplementedError(
+              "This relayout implementation requires an even number of column"
+              " tiles (to pack pairs of them for efficiency)"
+          )
+        # We pair up the consecutive column tiles, so each register is 32-bit.
+        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
+        # LLVM will realize that the paired up vectors actually came from the
+        # same 32-bit register and it will become a no-op.
+        col_minor_registers = np.moveaxis(registers, 1, -1)
+        flat_registers = [
+            utils.vector_concat((l, h))
+            for l, h in zip(
+                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
+            )
+        ]
+        registers = np.asarray(flat_registers, dtype=object).reshape(
+            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
         )
-        for idx, reg in np.ndenumerate(self.registers):
-          assert ir.VectorType(reg.type).shape == [4]
-          if dtype_bitwidth == 16:
-            # A single vector is 64-bits, but shuffles are only 32-bit wide.
-            # We only shuffle the half that needs to go to other thread.
-            low = utils.vector_slice(reg, slice(0, 2))
-            high = utils.vector_slice(reg, slice(2, 4))
-            to_exchange = arith.select(is_even, high, low)
-            # Exchange values between even and odd threads.
-            exchanged = utils.shfl_bfly(to_exchange, 1)
-            low = arith.select(is_even, low, exchanged)
-            high = arith.select(is_even, exchanged, high)
-          elif dtype_bitwidth == 8:
-            # The vector is 32-bits, so we just shuffle the whole thing and
-            # use prmt to blend it with the local register.
-            exchanged = utils.shfl_bfly(reg, 1)
-            # Consider lanes 0 and 1, because the situation is symmetric for
-            # each pair. If we feed reg[lane] and exchanged[lane] (which is
-            # really the same as reg of the other lane) to prmt, we can index
-            # the elements of the result using the following indices:
-            #  reg[0]: 0 1 2 3      reg[1]: 8 9 10 11
-            # prmt[0]: 0 1 2 3              4 5 6 7
-            # prmt[1]: 4 5 6 7              0 1 2 3
-            # The expected outputs and their respective permutations are:
-            #  out[0]: 0 1 8 9      out[1]: 2 3 10 11
-            # prmt[0]: 0 1 4 5     prmt[1]: 6 7 2 3
-            # Note that the patterns still need to be flipped, since we listed
-            # bytes with LSB on the left, which is the opposite of how the
-            # numeric constants are spelled in Python (LSB on the right).
-            perm = arith.select(is_even, c(0x5410), c(0x3276))
-            blend = utils.prmt(reg, exchanged, perm)
-            low = utils.vector_slice(blend, slice(0, 2))
-            high = utils.vector_slice(blend, slice(2, 4))
-          else:
-            raise NotImplementedError(dtype_bitwidth)
+        registers = np.moveaxis(registers, -1, 1)
+      for idx, reg in np.ndenumerate(registers):
+        if dtype_bitwidth == 16:
+          assert reg.type.shape == [4]
+          # A single vector is 64-bits, but shuffles are only 32-bit wide.
+          # We only shuffle the half that needs to go to other thread.
+          low = utils.vector_slice(reg, slice(0, 2))
+          high = utils.vector_slice(reg, slice(2, 4))
+          to_exchange = arith.select(is_even, high, low)
+          # Exchange values between even and odd threads.
+          exchanged = utils.shfl_bfly(to_exchange, 1)
+          low = arith.select(is_even, low, exchanged)
+          high = arith.select(is_even, exchanged, high)
           new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
           new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
-        assert all(r is not None for r in new_registers)
-        return FragmentedArray(
-            _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
-        )
+        elif dtype_bitwidth == 8:
+          assert reg.type.shape == [4]
+          # The vector is 32-bits, so we just shuffle the whole thing and
+          # use prmt to blend it with the local register.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # Consider lanes 0 and 1, because the situation is symmetric for
+          # each pair. If we feed reg[lane] and exchanged[lane] (which is
+          # really the same as reg of the other lane) to prmt, we can index
+          # the elements of the result using the following indices:
+          #  reg[0]: 0 1 2 3      reg[1]: 8 9 10 11
+          # prmt[0]: 0 1 2 3              4 5 6 7
+          # prmt[1]: 4 5 6 7              0 1 2 3
+          # The expected outputs and their respective permutations are:
+          #  out[0]: 0 1 8 9      out[1]: 2 3 10 11
+          # prmt[0]: 0 1 4 5     prmt[1]: 6 7 2 3
+          # Note that the patterns still need to be flipped, since we listed
+          # bytes with LSB on the left, which is the opposite of how the
+          # numeric constants are spelled in Python (LSB on the right).
+          perm = arith.select(is_even, c(0x5410), c(0x3276))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(2):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+        else:
+          assert dtype_bitwidth == 4
+          assert reg.type.shape == [8]  # We paired up the registers above.
+          exchanged = utils.shfl_bfly(reg, 1)
+          # See comment above for a more complete explanation.
+          #  reg[0]: 0 1 2 3 16 17 18 19      reg[1]: 8 9 10 11 24 25 26 27
+          # prmt[0]: -0- -1- --2-- --3--              -4- --5-- --6-- --7--
+          # prmt[1]: -4- -5- --6-- --7--              -0- --1-- --2-- --3--
+          # The expected outputs and their respective permutations are:
+          #  out[0]: 0 1 8 9 16 17 24 25      out[1]: 2 3 10 11 18 19 26 27
+          # prmt[0]: -0- -4- --2-- --6--     prmt[1]: -5- --1-- --7-- --3--
+          perm = arith.select(is_even, c(0x6240), c(0x3715))
+          blend = utils.prmt(reg, exchanged, perm)
+          for i in range(4):
+            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
+            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
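# --- Illustrative sketch (editor's addition, not part of the diff) ---
# A plain-Python model of the prmt blends above, to sanity-check the selector
# constants. `simulate_prmt` is a hypothetical helper mimicking PTX prmt.b32
# in its default mode: nibble i of the selector picks output byte i from the
# 8-byte pool formed by (a, b). It is not a Mosaic GPU API.
def simulate_prmt(a, b, selector):
  pool = list(a) + list(b)  # bytes 0-3 come from a, bytes 4-7 from b
  return [pool[(selector >> (4 * i)) & 0xF] for i in range(4)]

# 8-bit path: lane 0 holds elements 0..3, its partner (lane 1) holds 8..11.
lane0, lane1 = [0, 1, 2, 3], [8, 9, 10, 11]
assert simulate_prmt(lane0, lane1, 0x5410) == [0, 1, 8, 9]    # even lane
assert simulate_prmt(lane1, lane0, 0x3276) == [2, 3, 10, 11]  # odd lane

# 4-bit path: each byte packs two elements, so model bytes as element pairs.
lane0 = [(0, 1), (2, 3), (16, 17), (18, 19)]
lane1 = [(8, 9), (10, 11), (24, 25), (26, 27)]
assert simulate_prmt(lane0, lane1, 0x6240) == [(0, 1), (8, 9), (16, 17), (24, 25)]
assert simulate_prmt(lane1, lane0, 0x3715) == [(2, 3), (10, 11), (18, 19), (26, 27)]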
+    if (
+        self.layout == WGMMA_LAYOUT_UPCAST_4X
+        and new_layout == WGMMA_LAYOUT_UPCAST_2X
+        and utils.bitwidth(self.mlir_dtype) == 4
+    ):
+      assert shape[0] % 64 == 0  # Should be implied by the layout
+      assert shape[1] % 32 == 0  # Should be implied by the layout
+      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
+      i32 = ir.IntegerType.get_signless(32)
+      c = lambda x: arith.constant(i32, x)
+      is_01 = arith.cmpi(
+          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
+      )
+      for idx, reg in np.ndenumerate(self.registers):
+        assert ir.VectorType(reg.type).shape == [8]
+        # The vector is 32-bits, so we just shuffle the whole thing and
+        # use prmt to blend it with the local register.
+        exchanged = utils.shfl_bfly(reg, 2)
+        # See comments above for conventions. Here we exchange data between
+        # threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
+        #  reg[0]: 0 1 2 3 4 5 6 7      reg[2]: 16 17 18 19 20 21 22 23
+        # prmt[0]: -0- -1- -2- -3-              --4-- --5-- --6-- --7--
+        # prmt[1]: -4- -5- -6- -7-              --0-- --1-- --2-- --3--
+        # The expected outputs and their respective permutations are:
+        #  out[0]: 0 1 2 3 16 17 18 19      out[2]: 4 5 6 7 20 21 22 23
+        # prmt[0]: -0- -1- --4-- --5--     prmt[2]: -6- -7- --2-- --3--
+        perm = arith.select(is_01, c(0x5410), c(0x3276))
+        blend = utils.prmt(reg, exchanged, perm)
+        for i in range(2):
+          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
+          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
+      assert all(r is not None for r in new_registers)
+      return FragmentedArray(
+          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
+      )
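# --- Illustrative sketch (editor's addition, not part of the diff) ---
# The same prmt model applied to the UPCAST_4X -> UPCAST_2X step above: the
# butterfly exchange now pairs lanes whose indices differ in bit 1 (0<->2,
# 1<->3), and each byte packs two 4-bit elements. Hypothetical helper only.
def simulate_prmt(a, b, selector):
  pool = list(a) + list(b)
  return [pool[(selector >> (4 * i)) & 0xF] for i in range(4)]

lane0 = [(0, 1), (2, 3), (4, 5), (6, 7)]          # elements 0..7
lane2 = [(16, 17), (18, 19), (20, 21), (22, 23)]  # elements 16..23
# Lanes 0/1 keep their low half and take the partner's low half.
assert simulate_prmt(lane0, lane2, 0x5410) == [(0, 1), (2, 3), (16, 17), (18, 19)]
# Lanes 2/3 take the partner's high half and keep their own high half.
assert simulate_prmt(lane2, lane0, 0x3276) == [(4, 5), (6, 7), (20, 21), (22, 23)]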
+    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
+      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
     if not isinstance(self.layout, WGSplatFragLayout):
       raise NotImplementedError(
           f"Cannot convert from {self.layout} to {new_layout}"
@@ -1288,7 +1374,9 @@ def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
           int_ty = ir.IntegerType.get_signless(group_size * 4)
           while vector_len - offset >= group_size:
             reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
-            reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
+            reg_slice_int = utils.bitcast(reg_slice, int_ty)
+            if int_ty != i32:
+              reg_slice_int = arith.extsi(i32, reg_slice_int)
             reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
             out_int_regs.extend(
                 upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
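# --- Illustrative sketch (editor's addition, not part of the diff) ---
# My reading of why both reg_slice_int and its >>4 version are produced:
# shifting the packed 4-bit register right by 4 aligns the odd-indexed
# nibbles with the byte positions of the even-indexed ones, so the same
# extraction recipe can serve both halves. Plain-Python bit-layout model
# only; it is not the upcast_to_bf16 implementation.
packed = 0
for i, nibble in enumerate([0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]):
  packed |= nibble << (4 * i)  # element i occupies bits [4*i, 4*i + 4)

shifted = packed >> 4
evens = [(packed >> (8 * i)) & 0xF for i in range(4)]   # elements 0, 2, 4, 6
odds = [(shifted >> (8 * i)) & 0xF for i in range(4)]   # elements 1, 3, 5, 7
assert evens == [0x1, 0x3, 0x5, 0x7] and odds == [0x2, 0x4, 0x6, 0x8]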