Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions tests/test-quantize-fns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
}

// Total dot product error
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2, const int nrc) {
GGML_UNUSED(qfns);

std::vector<uint8_t> tmp_q1(2*test_size);
Expand All @@ -91,7 +91,7 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr
vdot->from_float(test_data2, tmp_q2.data(), test_size);

float result = INFINITY;
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, nrc);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to expand the input data depending on nrc, don't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Can you tell me an appropriate test_size for nrc=2? Would it just be doubled?


const float dot_ref = dot_product(test_data1, test_data2, test_size);

Expand Down Expand Up @@ -163,7 +163,7 @@ int main(int argc, char * argv[]) {
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
}

const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 1);
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
? MAX_DOT_PRODUCT_ERROR_LOWBIT
Expand All @@ -175,6 +175,16 @@ int main(int argc, char * argv[]) {
if (failed || verbose) {
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
}

// Test nrc=2 path for types that support it
if (qfns_cpu->nrows == 2) {
const float vec_dot_error_nrc2 = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 2);
failed = !(vec_dot_error_nrc2 < max_allowed_error);
num_failed += failed;
if (failed || verbose) {
printf("%5s dot product error (nrc=2): %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error_nrc2);
}
}
}
}

Expand Down
Loading