@@ -287,40 +287,37 @@ static void ggml_compute_forward_dup_to_q(
287287 const int ir0 = dr * ith;
288288 const int ir1 = MIN (ir0 + dr, nr);
289289
290- if (ggml_is_contiguous (dst)) {
291- if (nb00 == sizeof (src_t )) {
292- if (ggml_get_type_traits_cpu (dst->type )->from_float ) {
293- // casting non-quantized types --> intermediate f32 --> quantized
294- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu (dst->type )->from_float ;
295- float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
296-
297- size_t id = 0 ;
298- size_t rs = nb0 * (ne00 / ggml_blck_size (dst->type ));
299- char * dst_ptr = (char *) dst->data ;
290+ if (ggml_is_contiguous (dst) &&
291+ nb00 == sizeof (src_t ) &&
292+ ggml_get_type_traits_cpu (dst->type )->from_float ) {
293+ // casting non-quantized types --> intermediate f32 --> quantized
294+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu (dst->type )->from_float ;
295+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
300296
301- for (int i03 = 0 ; i03 < ne03; i03++) {
302- for (int i02 = 0 ; i02 < ne02; i02++) {
303- id += rs * ir0;
304- for (int i01 = ir0; i01 < ir1; i01++) {
305- const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
297+ size_t id = 0 ;
298+ size_t rs = nb0 * (ne00 / ggml_blck_size (dst->type ));
299+ char * dst_ptr = (char *) dst->data ;
306300
307- for (int i00 = 0 ; i00 < ne00; i00++) {
308- src0_f32[i00] = type_conversion_table<src_t >::to_f32 (src0_ptr[i00]);
309- }
301+ for (int i03 = 0 ; i03 < ne03; i03++) {
302+ for (int i02 = 0 ; i02 < ne02; i02++) {
303+ id += rs * ir0;
304+ for (int i01 = ir0; i01 < ir1; i01++) {
305+ const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
310306
311- quantize_row_q (src0_f32, dst_ptr + id, ne00);
312- id += rs;
313- }
314- id += rs * (ne01 - ir1);
307+ for (int i00 = 0 ; i00 < ne00; i00++) {
308+ src0_f32[i00] = type_conversion_table<src_t >::to_f32 (src0_ptr[i00]);
315309 }
310+
311+ quantize_row_q (src0_f32, dst_ptr + id, ne00);
312+ id += rs;
316313 }
317- return ;
314+ id += rs * (ne01 - ir1) ;
318315 }
319- } // TODO: else
316+ }
317+ } else {
318+ // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
319+ GGML_ABORT (" not implemented" );
320320 }
321-
322- // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
323- GGML_ABORT (" not implemented" );
324321}
325322
326323// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
@@ -560,7 +557,7 @@ void ggml_compute_forward_dup(
560557 } break ;
561558 case GGML_TYPE_I32:
562559 {
563- /* */ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t , float >(params, dst);
560+ if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t , float >(params, dst);
564561 else GGML_ABORT (" not implemented" );
565562 } break ;
566563 default :
0 commit comments