Skip to content

Commit c2fd8e1

Browse files
authored
[AMD] Test transposed B for scaled dot fp8/bf8 types (#6078)
Enabled more tests in test_mxfp8_mxfp4_matmul for the AMD backend.
1 parent 65939a0 commit c2fd8e1

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -889,21 +889,26 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
889889
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
890890
pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
891891

892-
if B_DATA_TYPE != 'float4' and B_TRANS:
893-
pytest.skip(f'No need to transpose B for {B_DATA_TYPE}')
894-
895892
if not is_hip() and BLOCK_N == 256 and BLOCK_K == 256:
896893
NUM_STAGES = 2
897894

898895
torch.manual_seed(42)
899896

900-
def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = False):
897+
def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = True):
901898
if dtype == "float8e5":
902-
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
903-
v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
899+
if transpose:
900+
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
901+
v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
902+
else:
903+
v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e5m2).to(device).T
904+
v_ref = f8_to_f16(v.view(torch.float8_e5m2).T, dtype).to(torch.float32).T
904905
elif dtype == "float8e4nv":
905-
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
906-
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
906+
if transpose:
907+
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
908+
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
909+
else:
910+
v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device).T
911+
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn).T, dtype).to(torch.float32).T
907912
else:
908913
# float4
909914
if transpose:
@@ -921,8 +926,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
921926
a, a_ref = create_operand(A_DATA_TYPE, M, K, 1)
922927
b, b_ref = create_operand(B_DATA_TYPE, K, N, 0, B_TRANS)
923928

924-
a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=64.0)
925-
b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=64.0)
929+
a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=32.0)
930+
b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=32.0)
926931
a_scale = a_scale_mxfp4.data
927932
b_scale = b_scale_mxfp4.data
928933

0 commit comments

Comments
 (0)