@@ -142,7 +142,7 @@ static bool tensor_is_contiguous(const struct ggml_tensor * tensor) {
142142}
143143
144144static void test_roundtrip_on_chunk (
145- const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, bool use_reference,
145+ const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
146146 float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats
147147) {
148148 if (layer->type == GGML_TYPE_F16) {
@@ -156,7 +156,7 @@ static void test_roundtrip_on_chunk(
156156 if (use_reference) {
157157 qfns.from_float_ref (input_scratch, quantized_scratch, chunk_size);
158158 } else {
159- qfns .from_float (input_scratch, quantized_scratch, chunk_size);
159+ qfns_cpu .from_float (input_scratch, quantized_scratch, chunk_size);
160160 }
161161 qfns.to_float (quantized_scratch, output_scratch, chunk_size);
162162
@@ -166,7 +166,7 @@ static void test_roundtrip_on_chunk(
166166
167167// Run quantization function for a single layer and update error stats
168168static void test_roundtrip_on_layer (
169- std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, bool use_reference,
169+ std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
170170 const ggml_tensor * layer, std::vector<float > & input_scratch, std::vector<char > & quantized_scratch,
171171 std::vector<float > & output_scratch, error_stats & total_error, int max_thread = 0
172172) {
@@ -187,13 +187,13 @@ static void test_roundtrip_on_layer(
187187 int num_chunks = (nelements + chunk_size - 1 )/chunk_size;
188188
189189 if (num_chunks < 2 || max_thread < 2 ) {
190- test_roundtrip_on_chunk (layer, 0 , nelements, qfns, use_reference, input_scratch_ptr, quantized_scratch.data (),
190+ test_roundtrip_on_chunk (layer, 0 , nelements, qfns, qfns_cpu, use_reference, input_scratch_ptr, quantized_scratch.data (),
191191 output_scratch.data (), print_layer_stats ? layer_error : total_error);
192192 } else {
193193 auto & stats = print_layer_stats ? layer_error : total_error;
194194 std::mutex mutex;
195195 uint64_t counter = 0 ;
196- auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
196+ auto compute = [&mutex, &counter, &stats, &qfns, &qfns_cpu, nelements, layer, use_reference, input_scratch_ptr,
197197 &quantized_scratch, &output_scratch, chunk_size] () {
198198 error_stats local_stats {};
199199 while (true ) {
@@ -205,7 +205,7 @@ static void test_roundtrip_on_layer(
205205 }
206206 lock.unlock ();
207207 uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
208- test_roundtrip_on_chunk (layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
208+ test_roundtrip_on_chunk (layer, offset, chunk, qfns, qfns_cpu, use_reference, input_scratch_ptr + offset,
209209 quantized_scratch.data () + 4 *offset, output_scratch.data () + offset, local_stats);
210210 }
211211 };
@@ -371,8 +371,9 @@ int main(int argc, char ** argv) {
371371 if (!params.include_types .empty () && std::find (params.include_types .begin (), params.include_types .end (), i) == params.include_types .end ()) {
372372 continue ;
373373 }
374- const auto * qfns = ggml_get_type_traits (type);
375- if (qfns->from_float && qfns->to_float ) {
374+ const auto * qfns = ggml_get_type_traits (type);
375+ const auto * qfns_cpu = ggml_get_type_traits_cpu (type);
376+ if (qfns_cpu->from_float && qfns->to_float ) {
376377 if (params.verbose ) {
377378 printf (" testing %s ...\n " , ggml_type_name (type));
378379 }
@@ -393,7 +394,7 @@ int main(int argc, char ** argv) {
393394 test_roundtrip_on_layer (
394395 layer_name,
395396 params.per_layer_stats ,
396- *qfns,
397+ *qfns, *qfns_cpu,
397398 params.reference ,
398399 kv_tensor.second ,
399400 input_scratch,
0 commit comments