@@ -411,10 +411,9 @@ static void matmul_results_deinitialize(matmul_results_t* results) {
411411}
412412
413413// Returns the largest number of characters to print any matrix element.
414- static int get_max_elem_width (precision_t precision, iree_hal_dim_t rows,
415- iree_hal_dim_t row_start, iree_hal_dim_t row_end,
416- iree_hal_dim_t cols, iree_hal_dim_t col_start,
417- iree_hal_dim_t col_end,
414+ static int get_max_elem_width (iree_hal_dim_t rows, iree_hal_dim_t row_start,
415+ iree_hal_dim_t row_end, iree_hal_dim_t cols,
416+ iree_hal_dim_t col_start, iree_hal_dim_t col_end,
418417 iree_hal_element_type_t element_type,
419418 const uint8_t * matrix) {
420419 int max_elem_width = 0 ;
@@ -426,15 +425,14 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
426425 // NOTE: iree_max is a macro and may evaluate its args twice.
427426 char buf[64 ];
428427 int this_elem_width =
429- iree_test_utils_snprintf_value (buf, sizeof (buf), elem, precision );
428+ iree_test_utils_snprintf_value (buf, sizeof (buf), elem);
430429 max_elem_width = iree_max (max_elem_width, this_elem_width);
431430 }
432431 }
433432 return max_elem_width;
434433}
435434
436435// Prints |matrix| to |file|, with |label| as caption.
437- // |precision| controls how many decimals are printed for float values.
438436//
439437// If |other_matrix| is not NULL, then any matrix entries that disagree
440438// between |matrix| and |other_matrix| (according to
@@ -451,22 +449,21 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
451449// characters. According to
452450// https://www.unicode.org/reports/tr11/#Recommendations, a single emoji
453451// character should meet that requirement.
454- static void print_matrix (FILE* file, const char * label, precision_t precision ,
455- iree_hal_dim_t rows , iree_hal_dim_t row_start ,
456- iree_hal_dim_t row_end , iree_hal_dim_t cols ,
457- iree_hal_dim_t col_start, iree_hal_dim_t col_end,
452+ static void print_matrix (FILE* file, const char * label, iree_hal_dim_t rows ,
453+ iree_hal_dim_t row_start , iree_hal_dim_t row_end ,
454+ iree_hal_dim_t cols , iree_hal_dim_t col_start ,
455+ iree_hal_dim_t col_end,
458456 iree_hal_element_type_t element_type,
459457 const uint8_t * matrix, const uint8_t * other_matrix,
460458 const char * highlight) {
461459 IREE_ASSERT ((other_matrix == NULL ) == (highlight == NULL ));
462- int max_elem_width =
463- get_max_elem_width (precision, rows, row_start, row_end, cols, col_start,
464- col_end, element_type, matrix);
460+ int max_elem_width = get_max_elem_width (
461+ rows, row_start, row_end, cols, col_start, col_end, element_type, matrix);
465462 if (other_matrix) {
466463 // NOTE: iree_max is a macro and may evaluate its args twice.
467464 int other_matrix_max_elem_width =
468- get_max_elem_width (precision, rows, row_start, row_end, cols, col_start,
469- col_end, element_type, other_matrix);
465+ get_max_elem_width (rows, row_start, row_end, cols, col_start, col_end ,
466+ element_type, other_matrix);
470467 max_elem_width = iree_max (max_elem_width, other_matrix_max_elem_width);
471468 }
472469
@@ -489,7 +486,7 @@ static void print_matrix(FILE* file, const char* label, precision_t precision,
489486 !iree_test_utils_result_elements_agree (element, other_element);
490487 }
491488 char buf[64 ];
492- iree_test_utils_snprintf_value (buf, sizeof (buf), element, precision );
489+ iree_test_utils_snprintf_value (buf, sizeof (buf), element);
493490 fprintf (file, " %*s" , max_elem_width, buf);
494491 // See comment on |highlight| function parameter for why 2 spaces.
495492 // A 3rd space is added unconditionally to make it clear that a highlight
@@ -523,13 +520,13 @@ static iree_status_t check_matmul_failure(
523520 char actual_value_buf[32 ];
524521 char expected_value_buf[32 ];
525522 iree_test_utils_snprintf_value (actual_value_buf, sizeof (actual_value_buf),
526- actual_value, PRECISION_HIGH );
523+ actual_value);
527524 iree_test_utils_snprintf_value (expected_value_buf, sizeof (expected_value_buf),
528- expected_value, PRECISION_HIGH );
525+ expected_value);
529526 fprintf (file, " actual value: %s\n " , actual_value_buf);
530527 fprintf (file, " expected value: %s\n " , expected_value_buf);
531528
532- iree_hal_dim_t context = 8 ;
529+ iree_hal_dim_t context = 16 ;
533530 const char * context_env = getenv (" IREE_MATMUL_TEST_SHOW_CONTEXT" );
534531 if (context_env) {
535532 if (1 != sscanf (context_env, " %" PRIdim, &context)) {
@@ -540,39 +537,36 @@ static iree_status_t check_matmul_failure(
540537 }
541538 }
542539 iree_hal_dim_t m_start =
543- (iree_hal_dim_t )iree_max (0 , (int64_t )row - (int64_t )context);
544- iree_hal_dim_t m_end = iree_min (results->m , row + context);
540+ (iree_hal_dim_t )iree_max (0 , (int64_t )row - (int64_t )context / 2 );
541+ iree_hal_dim_t m_end = iree_min (results->m , m_start + context);
545542 iree_hal_dim_t n_start =
546- (iree_hal_dim_t )iree_max (0 , (int64_t )col - (int64_t )context);
547- iree_hal_dim_t n_end = iree_min (results->n , col + context);
543+ (iree_hal_dim_t )iree_max (0 , (int64_t )col - (int64_t )context / 2 );
544+ iree_hal_dim_t n_end = iree_min (results->n , n_start + context);
548545 iree_hal_dim_t k_start = 0 ;
549- iree_hal_dim_t k_end = iree_min (results->k , 2 * context);
550- // [k_start, k_end) could be arbitrarily long at this point. Constrain it a
551- // bit to avoid huge output.
552- k_end = iree_min (k_end, k_start + 4 * context);
546+ iree_hal_dim_t k_end = iree_min (results->k , context);
553547
554548 fprintf (file, " \n " );
555- print_matrix (file, " left-hand side" , PRECISION_LOW, results->m , m_start,
556- m_end, results-> k , k_start, k_end, results->lhs_type ,
557- results-> lhs_contents . data , NULL , NULL );
549+ print_matrix (file, " left-hand side" , results->m , m_start, m_end, results-> k ,
550+ k_start, k_end, results->lhs_type , results-> lhs_contents . data ,
551+ NULL , NULL );
558552 fprintf (file, " \n " );
559- print_matrix (file, " right-hand side" , PRECISION_LOW, results->k , k_start,
560- k_end, results-> n , n_start, n_end, results->rhs_type ,
561- results-> rhs_contents . data , NULL , NULL );
553+ print_matrix (file, " right-hand side" , results->k , k_start, k_end, results-> n ,
554+ n_start, n_end, results->rhs_type , results-> rhs_contents . data ,
555+ NULL , NULL );
562556 fprintf (file, " \n " );
563557 if (results->acc_contents .data ) {
564- print_matrix (file, " input accumulator" , PRECISION_LOW, results->m , m_start,
565- m_end, results->n , n_start, n_end, results->acc_type ,
558+ print_matrix (file, " input accumulator" , results->m , m_start, m_end ,
559+ results->n , n_start, n_end, results->acc_type ,
566560 results->acc_contents .data , NULL , NULL );
567561 fprintf (file, " \n " );
568562 }
569- print_matrix (file, " expected result" , PRECISION_LOW, results->m , m_start,
570- m_end, results-> n , n_start, n_end, results->result_type ,
563+ print_matrix (file, " expected result" , results->m , m_start, m_end, results-> n ,
564+ n_start, n_end, results->result_type ,
571565 results->expected_contents .data , results->actual_contents .data ,
572566 iree_test_utils_emoji (true ));
573567 fprintf (file, " \n " );
574- print_matrix (file, " actual result" , PRECISION_LOW, results->m , m_start, m_end,
575- results-> n , n_start, n_end, results->result_type ,
568+ print_matrix (file, " actual result" , results->m , m_start, m_end, results-> n ,
569+ n_start, n_end, results->result_type ,
576570 results->actual_contents .data , results->expected_contents .data ,
577571 iree_test_utils_emoji (false ));
578572 fprintf (file, " \n " );
0 commit comments