diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 037c0582bbbf8..c2a4f4db43227 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -79,7 +79,7 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) { } // Total dot product error -static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) { +static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2, const int nrc) { GGML_UNUSED(qfns); std::vector tmp_q1(2*test_size); @@ -91,7 +91,7 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr vdot->from_float(test_data2, tmp_q2.data(), test_size); float result = INFINITY; - qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1); + qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, nrc); const float dot_ref = dot_product(test_data1, test_data2, test_size); @@ -163,7 +163,7 @@ int main(int argc, char * argv[]) { printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error); } - const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data()); + const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 1); const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT @@ -175,6 +175,16 @@ int main(int argc, char * argv[]) { if (failed || verbose) { printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); } + + // Test nrc=2 path for types that support it + if (qfns_cpu->nrows == 2) { + const float vec_dot_error_nrc2 = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 2); + failed = !(vec_dot_error_nrc2 < max_allowed_error); + num_failed += failed; + if (failed || verbose) { + printf("%5s dot product error (nrc=2): %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error_nrc2); + } + } } }