@@ -3424,6 +3424,10 @@ def convert_fp8_to_fp32(x, device, dtype_str):
         return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32)
     elif dtype_str == 'float8e5':
         return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32)
+    elif dtype_str == 'float8e4b8':
+        return torch.tensor(x, device=device).view(torch.float8_e4m3fnuz).to(torch.float32)
+    elif dtype_str == 'float8e5b16':
+        return torch.tensor(x, device=device).view(torch.float8_e5m2fnuz).to(torch.float32)
     assert "Unsupported float8 dtype"


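A minimal sketch of the conversion pattern used above, assuming a PyTorch build that exposes the AMD "fnuz" float8 dtypes (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz); the helper name is illustrative, not part of the patch:

    import torch

    def fp8_bits_to_fp32(bits, fp8_dtype=torch.float8_e4m3fnuz):
        # The raw values are 8-bit patterns, so .view() is a pure bitcast into the
        # requested float8 dtype (same element size); .to() then upcasts to float32
        # so results can be compared against a reference in full precision.
        return torch.tensor(bits, dtype=torch.uint8).view(fp8_dtype).to(torch.float32)

    # Example: fp8_bits_to_fp32([0x38, 0x40], torch.float8_e5m2fnuz)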
@@ -3553,12 +3557,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
     if capability[0] < 9 and in_dtype == 'float8e4nv':
         pytest.skip("float8e4nv not supported on sm <= 80")

-    if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'):
-        pytest.skip("float8e4nv and float8e5 not supported on HIP")
-    if is_hip() and not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())):
-        pytest.skip(f"{input_precision} not supported on HIP")
-    if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64):
-        pytest.skip("kpack too large for K")
+    if is_hip():
+        if in_dtype in ("float8e5", "float8e4nv") and not is_hip_mi350():
+            pytest.skip(f"{in_dtype} only supported on mi350")
+        if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_mi300():
+            pytest.skip(f"{in_dtype} only supported on mi300")
+        if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())):
+            pytest.skip(f"{input_precision} not supported on HIP")
+        if kpack == 2 and in_dtype == 'int8' and K < 64:
+            pytest.skip("kpack too large for K")
     if not is_hip() and kpack == 2:
         pytest.skip("Skip duplicated tests on nv path")

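One plausible shape for the is_hip_mi350() gate introduced above (mirroring the existing is_hip_mi300() helper), assuming a ROCm build of PyTorch where device properties expose gcnArchName and that MI350-class GPUs report a gfx950-family architecture; the test suite's actual helper may query the target through Triton's runtime instead:

    import torch

    def is_hip():
        # ROCm/HIP builds of PyTorch set torch.version.hip; CUDA builds leave it None.
        return torch.version.hip is not None

    def is_hip_mi350():
        # Hypothetical gate on the reported gfx architecture string.
        if not is_hip() or not torch.cuda.is_available():
            return False
        return "gfx950" in torch.cuda.get_device_properties(0).gcnArchName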
@@ -3686,6 +3693,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
             z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn)
         elif in_dtype == 'float8e5':
             z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2)
+        elif in_dtype == 'float8e4b8':
+            z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fnuz)
+        elif in_dtype == 'float8e5b16':
+            z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2fnuz)
         else:
             assert "Unsupported float8 dtype"
         z_ref = to_numpy(z_fp8.to(torch.float32))
@@ -6411,7 +6422,8 @@ def matmul_kernel( #
 @pytest.mark.parametrize("M, N, K", [(128, 256, 256)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)])
 @pytest.mark.parametrize(
-    "in_type_str", ['float8e5', 'float8e5b16', 'float8e4b8'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15'])
+    "in_type_str",
+    ['float8e5', 'float8e5b16', 'float8e4b8', 'float8e4nv'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15'])
 @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
 def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device):
     num_stages = 3
@@ -6423,6 +6435,8 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
         num_stages = 2
         if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_mi300():
             pytest.skip(f"{in_type_str} only supported on mi300")
+        if in_type_str in ("float8e5", "float8e4nv") and not is_hip_mi350():
+            pytest.skip(f"{in_type_str} only supported on mi350")

     check_type_supported(in_type_str, device)
     A = numpy_random((M, K), dtype_str=in_type_str)
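For context on the low_precision_acc values exercised here, a minimal single-block sketch of how such a knob is typically plumbed into tl.dot via its max_num_imprecise_acc argument; this kernel is illustrative and is not the matmul_kernel used by the test:

    import triton
    import triton.language as tl

    @triton.jit
    def fp8_dot_kernel(a_ptr, b_ptr, c_ptr,  # a/b point at fp8 inputs, c at a float32 output
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                       MAX_IMPRECISE_ACC: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
        # max_num_imprecise_acc bounds how many partial products may be accumulated
        # in reduced precision before being folded into the float32 accumulator.
        c = tl.dot(a, b, max_num_imprecise_acc=MAX_IMPRECISE_ACC)
        tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)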