Skip to content

Commit 01569e2

Browse files
bjacobGroverkss
authored andcommitted
e2e matmul test improvements (iree-org#19016)
Working on iree-org#18980 let me spend quality time with e2e matmul tests and suggested some changes. The main change is to simplify the printing of numerical values to always use high precision, meaning print all significant digits of floating point values. Since our tests generate small integral values and the intent is generally to be testing mostly the exact arithmetic that happens on small integral values, in most cases this doesn't make any difference. But I found that RDNA3 float arithmetic produces non-exact results even on those values. As a result, I got values like 1+epsilon where 1 was expected, causing a test to fail (since we didn't know we needed to opt out from requiring exact results) and the test output cryptically printed both values as "1". The other change is to more consistently print the same number of rows and columns regardless of whether we are at the start or in the middle of a dimension, and to have that number be what we call "context" (before, it was "2 * context"). Also a seasonal emoji change. Signed-off-by: Benoit Jacob <[email protected]>
1 parent d3b9a1b commit 01569e2

File tree

3 files changed

+49
-67
lines changed

3 files changed

+49
-67
lines changed

tools/testing/e2e/iree-e2e-matmul-test.cc

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,9 @@ static void matmul_results_deinitialize(matmul_results_t* results) {
411411
}
412412

413413
// Returns the largest number of characters to print any matrix element.
414-
static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
415-
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
416-
iree_hal_dim_t cols, iree_hal_dim_t col_start,
417-
iree_hal_dim_t col_end,
414+
static int get_max_elem_width(iree_hal_dim_t rows, iree_hal_dim_t row_start,
415+
iree_hal_dim_t row_end, iree_hal_dim_t cols,
416+
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
418417
iree_hal_element_type_t element_type,
419418
const uint8_t* matrix) {
420419
int max_elem_width = 0;
@@ -426,15 +425,14 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
426425
// NOTE: iree_max is a macro and may evaluate its args twice.
427426
char buf[64];
428427
int this_elem_width =
429-
iree_test_utils_snprintf_value(buf, sizeof(buf), elem, precision);
428+
iree_test_utils_snprintf_value(buf, sizeof(buf), elem);
430429
max_elem_width = iree_max(max_elem_width, this_elem_width);
431430
}
432431
}
433432
return max_elem_width;
434433
}
435434

436435
// Prints |matrix| to |file|, with |label| as caption.
437-
// |precision| controls how many decimals are printed for float values.
438436
//
439437
// If |other_matrix| is not NULL, then any matrix entries that disagree
440438
// between |matrix| and |other_matrix| (according to
@@ -451,22 +449,21 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
451449
// characters. According to
452450
// https://www.unicode.org/reports/tr11/#Recommendations, a single emoji
453451
// character should meet that requirement.
454-
static void print_matrix(FILE* file, const char* label, precision_t precision,
455-
iree_hal_dim_t rows, iree_hal_dim_t row_start,
456-
iree_hal_dim_t row_end, iree_hal_dim_t cols,
457-
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
452+
static void print_matrix(FILE* file, const char* label, iree_hal_dim_t rows,
453+
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
454+
iree_hal_dim_t cols, iree_hal_dim_t col_start,
455+
iree_hal_dim_t col_end,
458456
iree_hal_element_type_t element_type,
459457
const uint8_t* matrix, const uint8_t* other_matrix,
460458
const char* highlight) {
461459
IREE_ASSERT((other_matrix == NULL) == (highlight == NULL));
462-
int max_elem_width =
463-
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
464-
col_end, element_type, matrix);
460+
int max_elem_width = get_max_elem_width(
461+
rows, row_start, row_end, cols, col_start, col_end, element_type, matrix);
465462
if (other_matrix) {
466463
// NOTE: iree_max is a macro and may evaluate its args twice.
467464
int other_matrix_max_elem_width =
468-
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
469-
col_end, element_type, other_matrix);
465+
get_max_elem_width(rows, row_start, row_end, cols, col_start, col_end,
466+
element_type, other_matrix);
470467
max_elem_width = iree_max(max_elem_width, other_matrix_max_elem_width);
471468
}
472469

@@ -489,7 +486,7 @@ static void print_matrix(FILE* file, const char* label, precision_t precision,
489486
!iree_test_utils_result_elements_agree(element, other_element);
490487
}
491488
char buf[64];
492-
iree_test_utils_snprintf_value(buf, sizeof(buf), element, precision);
489+
iree_test_utils_snprintf_value(buf, sizeof(buf), element);
493490
fprintf(file, "%*s", max_elem_width, buf);
494491
// See comment on |highlight| function parameter for why 2 spaces.
495492
// A 3rd space is added unconditionally to make it clear that a highlight
@@ -523,13 +520,13 @@ static iree_status_t check_matmul_failure(
523520
char actual_value_buf[32];
524521
char expected_value_buf[32];
525522
iree_test_utils_snprintf_value(actual_value_buf, sizeof(actual_value_buf),
526-
actual_value, PRECISION_HIGH);
523+
actual_value);
527524
iree_test_utils_snprintf_value(expected_value_buf, sizeof(expected_value_buf),
528-
expected_value, PRECISION_HIGH);
525+
expected_value);
529526
fprintf(file, "actual value: %s\n", actual_value_buf);
530527
fprintf(file, "expected value: %s\n", expected_value_buf);
531528

532-
iree_hal_dim_t context = 8;
529+
iree_hal_dim_t context = 16;
533530
const char* context_env = getenv("IREE_MATMUL_TEST_SHOW_CONTEXT");
534531
if (context_env) {
535532
if (1 != sscanf(context_env, "%" PRIdim, &context)) {
@@ -540,39 +537,36 @@ static iree_status_t check_matmul_failure(
540537
}
541538
}
542539
iree_hal_dim_t m_start =
543-
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context);
544-
iree_hal_dim_t m_end = iree_min(results->m, row + context);
540+
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context / 2);
541+
iree_hal_dim_t m_end = iree_min(results->m, m_start + context);
545542
iree_hal_dim_t n_start =
546-
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context);
547-
iree_hal_dim_t n_end = iree_min(results->n, col + context);
543+
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context / 2);
544+
iree_hal_dim_t n_end = iree_min(results->n, n_start + context);
548545
iree_hal_dim_t k_start = 0;
549-
iree_hal_dim_t k_end = iree_min(results->k, 2 * context);
550-
// [k_start, k_end) could be arbitrarily long at this point. Constrain it a
551-
// bit to avoid huge output.
552-
k_end = iree_min(k_end, k_start + 4 * context);
546+
iree_hal_dim_t k_end = iree_min(results->k, context);
553547

554548
fprintf(file, "\n");
555-
print_matrix(file, "left-hand side", PRECISION_LOW, results->m, m_start,
556-
m_end, results->k, k_start, k_end, results->lhs_type,
557-
results->lhs_contents.data, NULL, NULL);
549+
print_matrix(file, "left-hand side", results->m, m_start, m_end, results->k,
550+
k_start, k_end, results->lhs_type, results->lhs_contents.data,
551+
NULL, NULL);
558552
fprintf(file, "\n");
559-
print_matrix(file, "right-hand side", PRECISION_LOW, results->k, k_start,
560-
k_end, results->n, n_start, n_end, results->rhs_type,
561-
results->rhs_contents.data, NULL, NULL);
553+
print_matrix(file, "right-hand side", results->k, k_start, k_end, results->n,
554+
n_start, n_end, results->rhs_type, results->rhs_contents.data,
555+
NULL, NULL);
562556
fprintf(file, "\n");
563557
if (results->acc_contents.data) {
564-
print_matrix(file, "input accumulator", PRECISION_LOW, results->m, m_start,
565-
m_end, results->n, n_start, n_end, results->acc_type,
558+
print_matrix(file, "input accumulator", results->m, m_start, m_end,
559+
results->n, n_start, n_end, results->acc_type,
566560
results->acc_contents.data, NULL, NULL);
567561
fprintf(file, "\n");
568562
}
569-
print_matrix(file, "expected result", PRECISION_LOW, results->m, m_start,
570-
m_end, results->n, n_start, n_end, results->result_type,
563+
print_matrix(file, "expected result", results->m, m_start, m_end, results->n,
564+
n_start, n_end, results->result_type,
571565
results->expected_contents.data, results->actual_contents.data,
572566
iree_test_utils_emoji(true));
573567
fprintf(file, "\n");
574-
print_matrix(file, "actual result", PRECISION_LOW, results->m, m_start, m_end,
575-
results->n, n_start, n_end, results->result_type,
568+
print_matrix(file, "actual result", results->m, m_start, m_end, results->n,
569+
n_start, n_end, results->result_type,
576570
results->actual_contents.data, results->expected_contents.data,
577571
iree_test_utils_emoji(false));
578572
fprintf(file, "\n");

tools/testing/e2e/test_utils.c

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ int32_t iree_test_utils_max_elements_to_check(void) {
5050
return FLAG_max_elements_to_check;
5151
}
5252

53-
const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🐞"; }
53+
const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🎃"; }
5454

5555
int iree_test_utils_calculate_check_every(iree_hal_dim_t tot_elements,
5656
iree_hal_dim_t no_div_of) {
@@ -182,9 +182,13 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
182182
return iree_test_utils_value_make_none();
183183
}
184184

185+
// Important: print all floating point values to FULL precision.
186+
// The audience is debugging low-level numerical bugs.
187+
// Since the values used in most tests are small and integral, these will
188+
// normally print just as concisely, while the extra precision requested here
189+
// will only kick in when it's needed, when there is a numerical bug.
185190
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
186-
iree_test_utils_e2e_value_t value,
187-
precision_t precision) {
191+
iree_test_utils_e2e_value_t value) {
188192
switch (value.type) {
189193
case IREE_TEST_UTILS_VALUE_TYPE_I8:
190194
return snprintf(buf, bufsize, "%" PRIi8, value.i8);
@@ -195,36 +199,27 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
195199
case IREE_TEST_UTILS_VALUE_TYPE_I64:
196200
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
197201
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
198-
return snprintf(buf, bufsize,
199-
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
202+
return snprintf(buf, bufsize, "%.3g",
200203
iree_math_f8e5m2_to_f32(value.f8_u8));
201204
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
202-
return snprintf(buf, bufsize,
203-
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
205+
return snprintf(buf, bufsize, "%.3g",
204206
iree_math_f8e4m3_to_f32(value.f8_u8));
205207
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
206-
return snprintf(buf, bufsize,
207-
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
208+
return snprintf(buf, bufsize, "%.3g",
208209
iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
209210
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
210-
return snprintf(buf, bufsize,
211-
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
211+
return snprintf(buf, bufsize, "%.3g",
212212
iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
213213
case IREE_TEST_UTILS_VALUE_TYPE_F16:
214-
return snprintf(buf, bufsize,
215-
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
214+
return snprintf(buf, bufsize, "%.5g",
216215
iree_math_f16_to_f32(value.f16_u16));
217216
case IREE_TEST_UTILS_VALUE_TYPE_BF16:
218-
return snprintf(buf, bufsize,
219-
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
217+
return snprintf(buf, bufsize, "%.5g",
220218
iree_math_bf16_to_f32(value.bf16_u16));
221219
case IREE_TEST_UTILS_VALUE_TYPE_F32:
222-
return snprintf(buf, bufsize,
223-
precision == PRECISION_HIGH ? "%.8g" : "%.4g", value.f32);
220+
return snprintf(buf, bufsize, "%.8g", value.f32);
224221
case IREE_TEST_UTILS_VALUE_TYPE_F64:
225-
return snprintf(buf, bufsize,
226-
precision == PRECISION_HIGH ? "%.16g" : "%.4g",
227-
value.f64);
222+
return snprintf(buf, bufsize, "%.16g", value.f64);
228223
default:
229224
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
230225
"unhandled value type"));

tools/testing/e2e/test_utils.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,6 @@ typedef struct iree_test_utils_value_t {
7676
};
7777
} iree_test_utils_e2e_value_t;
7878

79-
// Enum controlling how many decimals to print floats with.
80-
typedef enum iree_test_utils_precision_e {
81-
PRECISION_LOW,
82-
PRECISION_HIGH,
83-
} precision_t;
84-
8579
// Reads an element from a buffer given index.
8680
iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
8781
iree_hal_dim_t index, iree_hal_element_type_t result_type,
@@ -90,8 +84,7 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
9084
// Prints a iree_e2e_test_value_t to a string buffer. Returns the number of
9185
// characters written. Like snprintf.
9286
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
93-
iree_test_utils_e2e_value_t value,
94-
precision_t precision);
87+
iree_test_utils_e2e_value_t value);
9588

9689
// Returns true if |expected| and |actual| agree to tolerable accuracy.
9790
bool iree_test_utils_result_elements_agree(iree_test_utils_e2e_value_t expected,

0 commit comments

Comments
 (0)