@@ -3315,16 +3315,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx


-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", [
-    (M, N, K, col_a, col_b, type_a, type_b, 4)
-    for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
-    for col_a, col_b in itertools.product([True, False], repeat=2)
-    # We don't test e5m2 as its range + the uniform sampling overflows easily
-    # Tested locally and it works fine other than for ~10 entries out of 10_000
-    # which are of the size of 10**30
-    for type_a in ["e2m1", "e4m3"]
-    for type_b in ["e4m3"]
-])
+@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps",
+                         [(M, N, K, col_a, col_b, type_a, type_b, 4)
+                          for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
+                          for col_a, col_b in itertools.product([True, False], repeat=2)
+                          for type_a in ["e2m1", "e4m3", "e5m2"]
+                          for type_b in ["e4m3", "e5m2"]])
 def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
     if not is_cuda():
         pytest.skip("scaled_dot only supported on CUDA")
@@ -3355,7 +3351,7 @@ def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, s
         a_scale = tl.load(scale_a_ptr)
         c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b)
         out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
-        tl.store(out_ptr, c)
+        tl.store(out_ptr, c.to(tl.bfloat16))

     @triton.jit
     def mxfp_to_bf16_kernel(
@@ -3431,7 +3427,6 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]

         comp_dtype = torch.bfloat16
-        out_dtype = torch.float32

         x = x.contiguous()
         x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype)
@@ -3440,42 +3435,65 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         BLOCK_SIZE = 512
         grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
         mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
+        assert x_upcast.isfinite().all()

         y_upcast = y.view(type_fp8_y).to(comp_dtype)
-        return torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype))
+
+        class AccumulateInFp32:
+
+            def __enter__(self):
+                self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value
+
+        with AccumulateInFp32():
+            return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))

     torch.manual_seed(0)

-    def create_uint8(shape, col_major=False):
+    def create_uint8(shape, col_major=False, max_val=255):
         if col_major:
             shape = shape[:-2] + (shape[-1], shape[-2])
-        ret = torch.randint(1 << 8, shape, dtype=torch.uint8, device=device)
+        ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
         if col_major:
             ret = ret.mT
         return ret

     DIV_FACTOR = 2 if type_a == "e2m1" else 1
     x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
     y = create_uint8((K, N), col_major=col_b)
-    scale_x = create_uint8((M, K // 32))

-    z = x.new_empty((M, N), dtype=torch.float32)
+    # Sample scales that don't overflow, as otherwise the result is implementation defined (underflowing is alright)
+    # We subtract a reasonably large number (64) so that the sum of all the mxfp elements does not overflow
+    e_bits_type_a = int(type_a[1])
+    bias_type_a = (1 << (e_bits_type_a - 1)) - 1
+    max_exponent_type_a = (1 << e_bits_type_a) - 1 - bias_type_a
+    scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
+
+    def make_finite(x, dtype):
+        # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
+        # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme)
+        if dtype not in ("e5m2", "e4m3"):
+            return x
+        mask = 0x7C if dtype == "e5m2" else 0x7F
+        finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask
+        x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x)
+        x.copy_(x_finite)
+        return x
+
+    x = make_finite(x, type_a)
+    y = make_finite(y, type_b)
+
+    z = x.new_empty((M, N), dtype=torch.bfloat16)
     pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
                                   num_warps=num_warps)

     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

-    # dot_scale_ref computes the result in higher precision
-    # so we equalise all the non-finite values
-    # This also fixes a bug in our upcasting from e5m2 to bf16 where inf is not preserved
-    non_finite_z = ~z.isfinite()
-    z_ref[non_finite_z] = z[non_finite_z]
-    non_finite_ref = ~z_ref.isfinite()
-    z[non_finite_ref] = z_ref[non_finite_ref]
-
-    # generous rtol set because the ref is more precise than the fused
-    # (computes in higher dtype) and we are sampling the whole range of floats
-    torch.testing.assert_close(z, z_ref, equal_nan=True, atol=1e-5, rtol=1e-2)
+    # generous rtol as we are sampling the whole range of floats
+    torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)

     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
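
A side note on the scale sampling in the new test: below is a minimal, standalone sketch (not part of the diff) of the headroom arithmetic, assuming the usual MX convention that an e8m0 scale byte s encodes 2**(s - 127) and that an fp8 format with e exponent bits uses bias 2**(e - 1) - 1. Only the 255 cap, the per-format max exponent, and the 64 of headroom mirror the test; everything else is illustrative.

# Illustrative only: why max_val = 255 - max_exponent - 64 keeps the scaled dot finite in bf16 (e4m3 case).
import torch

E8M0_BIAS = 127                               # assumed: scale byte s represents 2.0 ** (s - 127)
e_bits = 4                                    # e4m3
bias = (1 << (e_bits - 1)) - 1                # 7
max_exponent = (1 << e_bits) - 1 - bias       # 8; the largest finite e4m3 value is 448 ~= 2**8.8

max_scale_byte = 255 - max_exponent - 64      # what create_uint8(..., max_val=...) samples up to
largest_element = 2.0 ** (max_scale_byte - E8M0_BIAS) * 448.0
upper_bound = 128 * largest_element           # at most K = 128 terms per output element

assert upper_bound < torch.finfo(torch.bfloat16).max   # ~3.39e38, so no overflow
print(f"largest scaled element ~ {largest_element:.3e}, dot upper bound ~ {upper_bound:.3e}")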
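
Likewise, a small self-contained check (also not part of the diff, and assuming a PyTorch build that ships the float8 dtypes) of the bit masks make_finite relies on: in e5m2 the exponent field is 0x7C and any all-ones exponent encodes Inf or NaN, while in the e4m3 "fn" variant only the patterns 0x7F and 0xFF are NaN.

# Illustrative only: verify the non-finite bit patterns assumed by make_finite.
import torch

bits = torch.arange(256).to(torch.uint8)
e5m2_vals = bits.view(torch.float8_e5m2).float()
e4m3_vals = bits.view(torch.float8_e4m3fn).float()

assert (((bits & 0x7C) == 0x7C) == ~e5m2_vals.isfinite()).all()   # Inf/NaN iff exponent all ones
assert (((bits & 0x7F) == 0x7F) == ~e4m3_vals.isfinite()).all()   # NaN only at 0x7F / 0xFF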