@@ -1590,6 +1590,20 @@ def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
15901590 assert val .dtype == unsigned_ty , f"Unexpected dtype for { float_format } . Got { val .dtype } "
15911591 return self .bitcast (val , triton_ty )
15921592
1593+ def verify_scaled_shape (self , M , N , K , lhs_scale , rhs_scale ):
1594+ if lhs_scale is not None :
1595+ scale_factor = 16 if lhs_scale .dtype .is_fp8e4nv () else 32
1596+ lhs_scale_shape = lhs_scale .type .shape
1597+ assert lhs_scale_shape == [
1598+ M , K // scale_factor
1599+ ], f"lhs_scale must be a tensor of shape [{ M } , { K // scale_factor } ]. Got { lhs_scale_shape } "
1600+ if rhs_scale is not None :
1601+ scale_factor = 16 if rhs_scale .dtype .is_fp8e4nv () else 32
1602+ rhs_scale_shape = rhs_scale .type .shape
1603+ assert rhs_scale_shape == [
1604+ N , K // scale_factor
1605+ ], f"rhs_scale must be a tensor of shape [{ N } , { K // scale_factor } ]. Got { rhs_scale_shape } "
1606+
15931607 def dot_scaled (self , lhs : TensorTy , lhs_scale : TensorTy , lhs_format : str , rhs : TensorTy ,
15941608 rhs_scale : Optional [TensorTy ], rhs_format : str , acc : TensorTy | None , fast_math : bool ,
15951609 lhs_k_pack : bool , rhs_k_pack : bool , out_dtype : tl .dtype ) -> TensorTy :
@@ -1621,8 +1635,11 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T
16211635 assert PACKED_B_DIM == PACKED_A_DIM , f"Reduction dimension should pack the same number of elements; (lhs: { lhs .shape } vs rhs: { rhs .shape } )"
16221636 #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
16231637 B = lhs .type .shape [0 ] if lhs_rank == 3 else None
1638+ K = K_LHS
16241639 if not lhs_k_pack :
16251640 M = M * PACKED_A
1641+ else :
1642+ K = K * PACKED_A
16261643 if not rhs_k_pack :
16271644 N = N * PACKED_B
16281645 ret_ty = tl .block_type (out_dtype , [B , M , N ] if B else [M , N ])
@@ -1634,6 +1651,8 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T
16341651 assert acc .type .shape == ret_ty .shape and acc .type .element_ty == out_dtype
16351652 rhs_scale_handle = None if rhs_scale_is_none else rhs_scale .handle
16361653 lhs_scale_handle = None if lhs_scale_is_none else lhs_scale .handle
1654+ self .verify_scaled_shape (M , N , K , None if lhs_scale_is_none else lhs_scale ,
1655+ None if rhs_scale_is_none else rhs_scale )
16371656 return self .tensor (
16381657 self .builder .create_dot_scaled (lhs .handle , lhs_scale_handle , lhs_format_enum , rhs .handle , rhs_scale_handle ,
16391658 rhs_format_enum , fast_math , lhs_k_pack , rhs_k_pack , acc_handle ), ret_ty )
0 commit comments