Skip to content

Commit 072f3d8

Browse files
Test for nrc=2 as well | i8mm kernels
1 parent 0889589 commit 072f3d8

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tests/test-quantize-fns.cpp

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
7979
}
8080

8181
// Total dot product error
82-
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
82+
static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2, const int nrc) {
8383
GGML_UNUSED(qfns);
8484

8585
std::vector<uint8_t> tmp_q1(2*test_size);
@@ -91,7 +91,7 @@ static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_tr
9191
vdot->from_float(test_data2, tmp_q2.data(), test_size);
9292

9393
float result = INFINITY;
94-
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
94+
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, nrc);
9595

9696
const float dot_ref = dot_product(test_data1, test_data2, test_size);
9797

@@ -163,7 +163,7 @@ int main(int argc, char * argv[]) {
163163
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
164164
}
165165

166-
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
166+
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 1);
167167
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
168168
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
169169
? MAX_DOT_PRODUCT_ERROR_LOWBIT
@@ -175,6 +175,17 @@ int main(int argc, char * argv[]) {
175175
if (failed || verbose) {
176176
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
177177
}
178+
179+
// Test i8mm path (nrc=2) for supported types
180+
if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q8_0 ||
181+
type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q6_K) {
182+
const float vec_dot_error_i8mm = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data(), 2);
183+
failed = !(vec_dot_error_i8mm < max_allowed_error);
184+
num_failed += failed;
185+
if (failed || verbose) {
186+
printf("%5s dot product error (i8mm): %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error_i8mm);
187+
}
188+
}
178189
}
179190
}
180191

0 commit comments

Comments
 (0)