@@ -3367,48 +3367,55 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
 
 
-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack",
-                         [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack)
+@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack",
+                         [(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, 4, mma, kpack)
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
-                          for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2", "bf16"]
+                          for rhs_scale in [False, True]
+                          for normal_type in ["e2m1", "e4m3", "e5m2"]
+                          for mxfp_type in ["e4m3", "e5m2", "bf16"]
                           for mma in ([32, 16] if is_hip() else [16])
                           for kpack in ([1, 2] if is_hip() else [1])])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
+def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack, device):
     if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
+        if rhs_scale:
+            pytest.skip("scales on rhs not yet supported for HIP")
         if not is_hip_cdna():
             pytest.skip("scaled_dot only implemented for HIP CDNA")
-        if "e4m3" in (type_a, type_b) and not is_hip_mi300():
-            pytest.skip(f"scaled_dot({type_a}, {type_b}) only implemented for MI300")
+        if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300():
+            pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
     @triton.jit
-    def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
+    def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                          type_b: tl.constexpr):
-        tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
-        IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
-        DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
-        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
-        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K
+        DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1
+        DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1
+        PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
+        PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
         a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0,
                                                                                 PACKED_BLOCK_K_A)[None, :] * stride_a1
         b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0,
                                                                                          BLOCK_N)[None, :] * stride_b1
 
-        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
-        scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
-
         a = tl.load(a_ptr)
         b = tl.load(b_ptr)
-        a_scale = tl.load(scale_a_ptr)
-        c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b)
+        SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
+        if a_scale is not None:
+            scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0,
+                                                                                               SCALE_BLOCK_K)[None, :]
+            a_scale = tl.load(scale_a_ptr)
+        if b_scale is not None:
+            scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0,
+                                                                                               SCALE_BLOCK_K)[None, :]
+            b_scale = tl.load(scale_b_ptr)
+        c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b)
         out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
         tl.store(out_ptr, c.to(tl.bfloat16))
 
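Note on the packing arithmetic in the hunk above: e2m1 is a 4-bit (fp4) format stored two values per byte along K, which is why the loaded tile spans only BLOCK_K // 2 bytes for that type, while e8m0 scales come one per 32 contiguous K elements. A minimal standalone sketch of the constexpr arithmetic, assuming an illustrative BLOCK_K and a hypothetical packed_k helper (neither is part of the test):

    # Model of dot_scale_kernel's constexpr arithmetic; values are illustrative.
    BLOCK_K = 128

    def packed_k(elem_type: str) -> int:
        # e2m1 packs two 4-bit values per byte along K; fp8 types use one byte per value.
        div_factor = 2 if elem_type == "e2m1" else 1
        return BLOCK_K // div_factor

    SCALE_BLOCK_K = BLOCK_K // 32  # one e8m0 scale byte covers 32 K elements

    assert packed_k("e2m1") == 64   # 128 fp4 values fit in 64 bytes
    assert packed_k("e4m3") == 128  # one byte per fp8 value
    assert SCALE_BLOCK_K == 4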
@@ -3481,22 +3488,31 @@ def mxfp_to_bf16_kernel(
         offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32)
 
-    def dot_scale_ref(x, scale, y, type_x, type_y):
-        e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
-        type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]
-
-        comp_dtype = torch.bfloat16
-
-        x = x.contiguous()
-        x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
-
-        N = x_upcast.numel()
-        BLOCK_SIZE = 512
-        grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
-        mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
-        assert x_upcast.isfinite().all()
-
-        y_upcast = y.view(type_y).to(comp_dtype)
+    def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y):
+
+        def upcast(v, scale, type, transposed):
+            comp_dtype = torch.bfloat16
+            if scale is None:
+                type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type]
+                return v.view(type).to(comp_dtype)
+            e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type]
+            # Packing is always on the K dimension, so we transpose before upcasting, then transpose back.
+            if transposed:
+                v = v.mT.contiguous()
+            v = v.contiguous()
+            v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
+            N = v_upcast.numel()
+            BLOCK_SIZE = 512
+            grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
+            mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE,
+                                      num_warps=num_warps)
+            assert v_upcast.isfinite().all()
+            if transposed:
+                v_upcast = v_upcast.mT
+            return v_upcast
+
+        x_upcast = upcast(x, scale_x, type_x, False)
+        y_upcast = upcast(y, scale_y, type_y, True)
 
         class AccumulateInFp32:
 
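For readers tracing the reference path: mxfp_to_bf16_kernel expands each packed element to bf16 and folds in its per-group e8m0 scale, where an e8m0 byte u encodes the power of two 2**(u - 127). A rough functional model of that scaling step (a sketch only; e8m0_decode and apply_block_scales are hypothetical names, and the real kernel also unpacks e2m1 nibbles, which this omits):

    import torch

    def e8m0_decode(scale_u8: torch.Tensor) -> torch.Tensor:
        # e8m0 stores only a biased exponent: value = 2**(u8 - 127), no sign or mantissa.
        return torch.exp2(scale_u8.to(torch.float32) - 127).to(torch.bfloat16)

    def apply_block_scales(v_bf16: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
        # v_bf16: (M, K) already-upcast elements; scale_u8: (M, K // 32), one scale per 32-wide K group.
        scales = e8m0_decode(scale_u8).repeat_interleave(32, dim=-1)
        return v_bf16 * scales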
@@ -3525,13 +3541,22 @@ def make_arg(shape, ty, col_major=False, max_val=255):
             ret = ret.mT
         return ret
 
-    DIV_FACTOR = 2 if type_a == "e2m1" else 1
-    x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
-    y = make_arg((K, N), type_b, col_major=col_b)
+    type_a = normal_type if not rhs_scale else mxfp_type
+    type_b = mxfp_type if not rhs_scale else normal_type
+
+    DIV_FACTOR_A = 2 if type_a == "e2m1" else 1
+    DIV_FACTOR_B = 2 if type_b == "e2m1" else 1
+    x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a)
+    y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b)
 
     # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
     # Max scale = 2**15
     scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)
+    scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15)
+    if rhs_scale:
+        scale_x = None
+    else:
+        scale_y = None
 
     def make_finite(x, dtype):
         # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
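The max_val = 127 + 15 bound on the sampled scales follows from the e8m0 encoding: u8 <= 142 caps the decoded scale at 2**15, and even the largest finite fp8 input (57344 for e5m2) times 2**15 stays far below the bf16 maximum, so the reference product cannot overflow. A quick check of that arithmetic:

    import torch

    max_scale = 2.0**15                # decoded from u8 = 127 + 15
    max_e5m2 = 57344.0                 # largest finite e5m2 magnitude
    worst_case = max_e5m2 * max_scale  # ~1.9e9
    assert worst_case < torch.finfo(torch.bfloat16).max  # bf16 max ~3.4e38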
@@ -3546,16 +3571,14 @@ def make_finite(x, dtype):
 
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)
-
     kernel_kwargs = {"num_warps": num_warps}
     if is_hip():
         kernel_kwargs["kpack"] = kpack
         kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)
-
-    z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
-
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b,
+                                  **kernel_kwargs)
+    z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b)
     # Bigger tolerance for AMD MI200 devices.
     # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values
     # to zero. Detailed info is at: