Skip to content

Commit 225baf2

Browse files
authored
Add e2e tests for F8E5M2FNUZ and F8E4M3FNUZ data-tiled MFMA on CDNA3 (#18888)
Signed-off-by: Benoit Jacob <[email protected]>
1 parent 4ad834b commit 225baf2

File tree

5 files changed

+209
-5
lines changed

5 files changed

+209
-5
lines changed

tests/e2e/matmul/CMakeLists.txt

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,7 +1526,7 @@ iree_generated_e2e_runner_test(
15261526

15271527
iree_generated_e2e_runner_test(
15281528
NAME
1529-
e2e_matmul_rocm_f16_large_cdna3_mfma_data_tiled
1529+
e2e_matmul_rocm_f16_cdna3_mfma_data_tiled
15301530
TEST_TYPE
15311531
matmul
15321532
GENERATOR
@@ -1555,7 +1555,7 @@ iree_generated_e2e_runner_test(
15551555

15561556
iree_generated_e2e_runner_test(
15571557
NAME
1558-
e2e_matmul_rocm_i8_large_cdna3_mfma_data_tiled
1558+
e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
15591559
TEST_TYPE
15601560
matmul
15611561
GENERATOR
@@ -1584,7 +1584,7 @@ iree_generated_e2e_runner_test(
15841584

15851585
iree_generated_e2e_runner_test(
15861586
NAME
1587-
e2e_matmul_rocm_f32_large_cdna3_mfma_data_tiled
1587+
e2e_matmul_rocm_f32_cdna3_mfma_data_tiled
15881588
TEST_TYPE
15891589
matmul
15901590
GENERATOR
@@ -1611,6 +1611,64 @@ iree_generated_e2e_runner_test(
16111611
"requires-gpu-cdna3"
16121612
)
16131613

1614+
# End-to-end matmul test with f8E5M2FNUZ inputs and an f32 accumulator,
# compiled for the "rocm" backend with the data-tiling flags below and run
# through the "hip" driver. Labelled as requiring a CDNA3 GPU.
iree_generated_e2e_runner_test(
  NAME e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled
  TEST_TYPE matmul
  GENERATOR "generate_e2e_matmul_tests.py"
  GENERATOR_ARGS
    "--lhs_rhs_type=f8E5M2FNUZ"
    "--acc_type=f32"
  TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test
  TARGET_BACKENDS "rocm"
  DRIVERS "hip"
  COMPILER_FLAGS
    ${IREE_HIP_TEST_COMPILER_FLAGS}
    "--iree-opt-data-tiling"
    "--iree-global-opt-experimental-rocm-data-tiling"
    "--iree-global-opt-enable-early-materialization=true"
  LABELS
    "noasan"
    "nomsan"
    "notsan"
    "noubsan"
    "requires-gpu-cdna3"
)
1642+
1643+
# End-to-end matmul test with f8E4M3FNUZ inputs and an f32 accumulator,
# compiled for the "rocm" backend with the data-tiling flags below and run
# through the "hip" driver. Labelled as requiring a CDNA3 GPU.
iree_generated_e2e_runner_test(
  NAME e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled
  TEST_TYPE matmul
  GENERATOR "generate_e2e_matmul_tests.py"
  GENERATOR_ARGS
    "--lhs_rhs_type=f8E4M3FNUZ"
    "--acc_type=f32"
  TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test
  TARGET_BACKENDS "rocm"
  DRIVERS "hip"
  COMPILER_FLAGS
    ${IREE_HIP_TEST_COMPILER_FLAGS}
    "--iree-opt-data-tiling"
    "--iree-global-opt-experimental-rocm-data-tiling"
    "--iree-global-opt-enable-early-materialization=true"
  LABELS
    "noasan"
    "nomsan"
    "notsan"
    "noubsan"
    "requires-gpu-cdna3"
)
1671+
16141672
endif()
16151673

16161674
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")

tests/e2e/matmul/generate_e2e_matmul_tests.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ class MatrixElemTypeId(enum.Enum):
2727
I32 = "i32"
2828
F32 = "f32"
2929
F16 = "f16"
30-
F8E4M3FNUZ = "f8E4M3FNUZ"
3130
BF16 = "bf16"
31+
F8E5M2 = "f8E5M2"
32+
F8E4M3 = "f8E4M3"
33+
F8E5M2FNUZ = "f8E5M2FNUZ"
34+
F8E4M3FNUZ = "f8E4M3FNUZ"
3235

3336

3437
# Enumerates the collections of shapes that we can generate tests for.
@@ -905,7 +908,17 @@ def parse_arguments():
905908
parser.add_argument(
906909
"--lhs_rhs_type",
907910
type=str,
908-
choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"],
911+
choices=[
912+
"i32",
913+
"i8",
914+
"f32",
915+
"f16",
916+
"bf16",
917+
"f8E5M2",
918+
"f8E4M3",
919+
"f8E5M2FNUZ",
920+
"f8E4M3FNUZ",
921+
],
909922
help="Numeric type of input matrices",
910923
required=True,
911924
)
@@ -999,6 +1012,12 @@ def write_calls_file(functions, calls, filename, requirements):
9991012
def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
10001013
if acc_type != MatrixElemTypeId.NONE:
10011014
return acc_type
1015+
if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
1016+
return MatrixElemTypeId.F32
1017+
if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
1018+
return MatrixElemTypeId.F32
1019+
if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
1020+
return MatrixElemTypeId.F32
10021021
if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
10031022
return MatrixElemTypeId.F32
10041023
if lhs_rhs_type == MatrixElemTypeId.I8:

tools/testing/e2e/iree-e2e-matmul-test.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,29 @@ static void reference_matmul_bf16_bf16_f32_f32(
128128
result_data[n + m * n_size] = acc;
129129
}
130130

131+
#define REFERENCE_MATMUL_F8(LHSTYPE, RHSTYPE) \
132+
static void reference_matmul_##LHSTYPE##_##RHSTYPE##_f32_f32( \
133+
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \
134+
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \
135+
iree_hal_element_type_t acc_type, bool transpose_rhs, \
136+
const uint8_t* lhs_data, const uint8_t* rhs_data, const float* acc_data, \
137+
float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { \
138+
float acc = acc_data ? acc_data[n + m * n_size] : 0; \
139+
for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
140+
float lhs_float = \
141+
iree_math_##LHSTYPE##_to_f32(lhs_data[k + m * k_size]); \
142+
float rhs_float = iree_math_##RHSTYPE##_to_f32( \
143+
rhs_data[transpose_rhs ? k + n * k_size : n + k * n_size]); \
144+
acc += lhs_float * rhs_float; \
145+
} \
146+
result_data[n + m * n_size] = acc; \
147+
}
148+
149+
REFERENCE_MATMUL_F8(f8e5m2, f8e5m2)
150+
REFERENCE_MATMUL_F8(f8e4m3, f8e4m3)
151+
REFERENCE_MATMUL_F8(f8e5m2fnuz, f8e5m2fnuz)
152+
REFERENCE_MATMUL_F8(f8e4m3fnuz, f8e4m3fnuz)
153+
131154
// Helper for reference_matmul.
132155
// Computes one element in the result matrix.
133156
static iree_status_t reference_matmul_element(
@@ -185,6 +208,34 @@ static iree_status_t reference_matmul_element(
185208
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
186209
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
187210
(const float*)acc_data, (float*)result_data, m, n);
211+
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
212+
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 &&
213+
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
214+
reference_matmul_f8e5m2_f8e5m2_f32_f32(
215+
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
216+
(const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
217+
(const float*)acc_data, (float*)result_data, m, n);
218+
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
219+
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 &&
220+
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
221+
reference_matmul_f8e4m3_f8e4m3_f32_f32(
222+
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
223+
(const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
224+
(const float*)acc_data, (float*)result_data, m, n);
225+
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
226+
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ &&
227+
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
228+
reference_matmul_f8e5m2fnuz_f8e5m2fnuz_f32_f32(
229+
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
230+
(const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
231+
(const float*)acc_data, (float*)result_data, m, n);
232+
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
233+
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ &&
234+
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
235+
reference_matmul_f8e4m3fnuz_f8e4m3fnuz_f32_f32(
236+
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
237+
(const uint8_t*)lhs_data, (const uint8_t*)rhs_data,
238+
(const float*)acc_data, (float*)result_data, m, n);
188239
} else {
189240
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
190241
"unhandled combination of element types in matmul");

tools/testing/e2e/test_utils.c

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,36 @@ iree_test_utils_e2e_value_t iree_test_utils_value_make_i32(int32_t value) {
9393
return result;
9494
}
9595

96+
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) {
97+
iree_test_utils_e2e_value_t result;
98+
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2;
99+
result.f8_u8 = value;
100+
return result;
101+
}
102+
103+
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) {
104+
iree_test_utils_e2e_value_t result;
105+
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3;
106+
result.f8_u8 = value;
107+
return result;
108+
}
109+
110+
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ(
111+
uint16_t value) {
112+
iree_test_utils_e2e_value_t result;
113+
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ;
114+
result.f8_u8 = value;
115+
return result;
116+
}
117+
118+
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ(
119+
uint16_t value) {
120+
iree_test_utils_e2e_value_t result;
121+
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ;
122+
result.f8_u8 = value;
123+
return result;
124+
}
125+
96126
iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) {
97127
iree_test_utils_e2e_value_t result;
98128
result.type = IREE_TEST_UTILS_VALUE_TYPE_F16;
@@ -123,6 +153,14 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
123153
return iree_test_utils_value_make_i16(((int16_t*)data)[index]);
124154
} else if (iree_hal_element_type_is_integer(result_type, 32)) {
125155
return iree_test_utils_value_make_i32(((int32_t*)data)[index]);
156+
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) {
157+
return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]);
158+
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) {
159+
return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]);
160+
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) {
161+
return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]);
162+
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) {
163+
return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]);
126164
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
127165
return iree_test_utils_value_make_f16(((uint16_t*)data)[index]);
128166
} else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
@@ -147,6 +185,22 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
147185
return snprintf(buf, bufsize, "%" PRIi32, value.i32);
148186
case IREE_TEST_UTILS_VALUE_TYPE_I64:
149187
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
188+
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
189+
return snprintf(buf, bufsize,
190+
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
191+
iree_math_f8e5m2_to_f32(value.f8_u8));
192+
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
193+
return snprintf(buf, bufsize,
194+
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
195+
iree_math_f8e4m3_to_f32(value.f8_u8));
196+
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
197+
return snprintf(buf, bufsize,
198+
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
199+
iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
200+
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
201+
return snprintf(buf, bufsize,
202+
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
203+
iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
150204
case IREE_TEST_UTILS_VALUE_TYPE_F16:
151205
return snprintf(buf, bufsize,
152206
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
@@ -257,6 +311,18 @@ void iree_test_utils_write_element(iree_hal_element_type_t element_type,
257311
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
258312
*(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
259313
break;
314+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
315+
*(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value);
316+
break;
317+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
318+
*(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value);
319+
break;
320+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
321+
*(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value);
322+
break;
323+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
324+
*(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value);
325+
break;
260326
WRITE_ELEMENT_CASE(FLOAT_32, float)
261327
WRITE_ELEMENT_CASE(FLOAT_64, double)
262328
// clang-format on
@@ -296,6 +362,10 @@ void iree_test_utils_get_min_max_for_element_type(
296362
*max = +4;
297363
break;
298364
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
365+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
366+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
367+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
368+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
299369
*min = -2;
300370
*max = +2;
301371
break;

tools/testing/e2e/test_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ typedef enum iree_test_utils_value_type_e {
4848
IREE_TEST_UTILS_VALUE_TYPE_F64 = 7,
4949
// bfloat16
5050
IREE_TEST_UTILS_VALUE_TYPE_BF16 = 8,
51+
// 8-bit float types.
52+
IREE_TEST_UTILS_VALUE_TYPE_F8E5M2 = 9,
53+
IREE_TEST_UTILS_VALUE_TYPE_F8E4M3 = 10,
54+
IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ = 11,
55+
IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ = 12,
5156
} iree_test_utils_value_type_t;
5257

5358
// Maximum size, in bytes, of any value type we can represent.
@@ -64,6 +69,7 @@ typedef struct iree_test_utils_value_t {
6469
float f32;
6570
uint16_t f16_u16;
6671
uint16_t bf16_u16;
72+
uint8_t f8_u8;
6773
double f64;
6874
uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all
6975
// value types

0 commit comments

Comments
 (0)