@@ -1032,37 +1032,48 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
10321032 )
10331033 reg_type = self .registers .flat [0 ].type
10341034 is_vector_reg = ir .VectorType .isinstance (reg_type )
1035- reg_shape = tuple (ir .VectorType (reg_type ).shape ) if is_vector_reg else ()
1036- if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2 ,):
1035+ reg_shape = tuple (ir .VectorType (reg_type ).shape ) if is_vector_reg else (1 ,)
1036+ [vector_len ] = reg_shape # This is meant to be a 1D assertion.
1037+ if cur_dtype == i8 and self .is_signed and new_dtype == bf16 and vector_len in {2 , 4 }:
10371038 new_registers = np .empty_like (self .registers )
1038- for idx , reg in np .ndenumerate (self .registers ):
1039- reg_16 = vector .bitcast (ir .VectorType .get ((1 ,), i16 ), reg )
1040- val_16 = llvm .extractelement (reg_16 , c (0 , i32 ))
1039+ def upcast_to_bf16 (reg , high ):
10411040 # We first embed the s8 into a bf16 with the exponent equal to
10421041 # bias + mantissa bits. Then, we zero the msb that didn't fit into the
10431042 # mantissa, zero out all bits other than msb, and subtract the last
10441043 # two values from each other. This takes advantage of the fact that the
10451044 # lsb of the exponent (msb of the second byte) is zero, which allows us
10461045 # to losslessly pack the msb there. When 1, it doubles the value of s2,
10471046 # making the result negative.
1048- new_val_32 = llvm .inline_asm (
1047+ return llvm .inline_asm (
10491048 i32 ,
1050- [val_16 ],
1051- """
1052- {
1049+ [reg ],
1050+ f """
1051+ {{
10531052 .reg .b32 s<3>;
1054- prmt.b32 s0, $1, 0x43, 0x4140;
1053+ prmt.b32 s0, $1, 0x43, { 0x4342 if high else 0x4140 } ;
10551054 and.b32 s1, s0, 0xff7fff7f;
10561055 and.b32 s2, s0, 0xff80ff80;
10571056 sub.bf16x2 $0, s1, s2;
1058- }
1057+ }}
10591058 """ ,
10601059 "=r,r" ,
10611060 )
1062- new_vec = llvm .mlir_undef (ir .VectorType .get ((1 ,), i32 ))
1063- new_vec = llvm .insertelement (new_vec , new_val_32 , c (0 , i32 ))
1061+ empty_vec_32 = llvm .mlir_undef (ir .VectorType .get ((vector_len // 2 ,), i32 ))
1062+ for idx , reg in np .ndenumerate (self .registers ):
1063+ if vector_len == 2 :
1064+ reg_16 = vector .bitcast (ir .VectorType .get ((1 ,), i16 ), reg )
1065+ new_reg_32 = upcast_to_bf16 (reg_16 , high = False )
1066+ new_vec_32 = llvm .insertelement (empty_vec_32 , new_reg_32 , c (0 , i32 ))
1067+ elif vector_len == 4 :
1068+ reg_32 = vector .bitcast (ir .VectorType .get ((1 ,), i32 ), reg )
1069+ low = upcast_to_bf16 (reg_32 , high = False )
1070+ high = upcast_to_bf16 (reg_32 , high = True )
1071+ new_vec_32 = llvm .insertelement (empty_vec_32 , low , c (0 , i32 ))
1072+ new_vec_32 = llvm .insertelement (new_vec_32 , high , c (1 , i32 ))
1073+ else :
1074+ raise NotImplementedError (vector_len )
10641075 new_registers [idx ] = vector .bitcast (
1065- ir .VectorType .get ((2 ,), new_dtype ), new_vec
1076+ ir .VectorType .get ((vector_len ,), new_dtype ), new_vec_32
10661077 )
10671078 return FragmentedArray (
10681079 _registers = new_registers , _layout = self .layout , _is_signed = is_signed
0 commit comments