@@ -889,21 +889,26 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
     if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
         pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
 
-    if B_DATA_TYPE != 'float4' and B_TRANS:
-        pytest.skip(f'No need to transpose B for {B_DATA_TYPE}')
-
     if not is_hip() and BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = 2
 
     torch.manual_seed(42)
 
-    def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = False):
+    def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = True):
         if dtype == "float8e5":
-            v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
-            v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
+            if transpose:
+                v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
+                v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
+            else:
+                v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e5m2).to(device).T
+                v_ref = f8_to_f16(v.view(torch.float8_e5m2).T, dtype).to(torch.float32).T
         elif dtype == "float8e4nv":
-            v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
-            v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
+            if transpose:
+                v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
+                v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
+            else:
+                v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device).T
+                v_ref = f8_to_f16(v.view(torch.float8_e4m3fn).T, dtype).to(torch.float32).T
         else:
             # float4
             if transpose:
@@ -921,8 +926,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
     a, a_ref = create_operand(A_DATA_TYPE, M, K, 1)
     b, b_ref = create_operand(B_DATA_TYPE, K, N, 0, B_TRANS)
 
-    a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=64.0)
-    b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=64.0)
+    a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=32.0)
+    b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=32.0)
     a_scale = a_scale_mxfp4.data
     b_scale = b_scale_mxfp4.data
 
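The new `else` branches in `create_operand` give the fp8 operands a column-major memory layout while keeping the same logical shape: the data is allocated as `(size1, size0)` and returned as a `.T` view. A minimal standalone sketch of that layout trick (using `uint8` on CPU so it runs anywhere; `make_operand` is a hypothetical helper, not part of the test):

import torch

def make_operand(size0: int, size1: int, row_major: bool = True) -> torch.Tensor:
    # Hypothetical helper mirroring the diff's trick: when row_major is False,
    # allocate the transposed shape and return a .T view, so the logical
    # (size0, size1) tensor is column-major in memory.
    if row_major:
        return torch.randint(20, 40, (size0, size1), dtype=torch.uint8)
    return torch.randint(20, 40, (size1, size0), dtype=torch.uint8).T

row = make_operand(4, 8)
col = make_operand(4, 8, row_major=False)
assert row.shape == col.shape == (4, 8)
assert row.stride() == (8, 1)  # rows are contiguous
assert col.stride() == (1, 4)  # columns are contiguous

This also explains the double `.T` in the non-transposed `v_ref` lines: the inner `v.view(...).T` undoes the transpose to hand `f8_to_f16` the contiguous underlying data, and the outer `.T` restores the logical orientation of the reference tensor.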