Skip to content

Commit 1860cf9

Browse files
authored
deduplicate copy functions
1 parent 4162ffe commit 1860cf9

File tree

3 files changed

+38
-122
lines changed

3 files changed

+38
-122
lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 11 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,17 @@
22

33
#include "ggml-common.h"
44

5-
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
6-
*dst = *src;
7-
}
8-
9-
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
10-
*dst = __float2half(*src);
11-
}
12-
13-
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
14-
*dst = *src;
15-
}
16-
17-
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
18-
*dst = *src;
19-
}
20-
21-
static __device__ __forceinline__ void convert_f16_bf16(const half * src, nv_bfloat16 * dst) {
5+
template<typename src_t, typename dst_t>
6+
static __device__ __forceinline__ void convert_to_flt(const src_t * src, dst_t * dst) {
227
*dst = float(*src);
238
}
249

25-
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
26-
*dst = *src;
27-
}
28-
29-
static __device__ __forceinline__ void convert_bf16_f16(const nv_bfloat16 * src, half * dst) {
10+
template<typename src_t>
11+
static __device__ __forceinline__ void convert_to_f16(const src_t * src, half * dst) {
3012
*dst = __float2half(*src);
3113
}
3214

33-
static __device__ __forceinline__ void convert_bf16_f32(const nv_bfloat16 * src, float * dst) {
15+
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
3416
*dst = *src;
3517
}
3618

@@ -242,34 +224,16 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
242224
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
243225
}
244226

245-
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
246-
convert_f32_f32((const float *)cxi, (float *)cdsti);
227+
template<typename src_t, typename dst_t>
228+
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
229+
convert_to_flt((const src_t *)cxi, (dst_t *)cdsti);
247230
}
248231

249-
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
250-
convert_f32_f16((const float *)cxi, (half *)cdsti);
251-
}
252-
253-
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
254-
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
232+
template<typename src_t>
233+
static __device__ void cpy_1_to_f16(const char * cxi, char * cdsti) {
234+
convert_to_f16((const src_t *)cxi, (half *)cdsti);
255235
}
256236

257237
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
258238
convert_f16_f16((const half *)cxi, (half *)cdsti);
259239
}
260-
261-
static __device__ void cpy_1_f16_bf16(const char * cxi, char * cdsti) {
262-
convert_f16_bf16((const half *)cxi, (nv_bfloat16 *)cdsti);
263-
}
264-
265-
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
266-
convert_f16_f32((const half *)cxi, (float *)cdsti);
267-
}
268-
269-
static __device__ void cpy_1_bf16_f16(const char * cxi, char * cdsti) {
270-
convert_bf16_f16((const nv_bfloat16 *)cxi, (half *)cdsti);
271-
}
272-
273-
static __device__ void cpy_1_bf16_f32(const char * cxi, char * cdsti) {
274-
convert_bf16_f32((const nv_bfloat16 *)cxi, (float *)cdsti);
275-
}

ggml/src/ggml-cuda/cpy.cu

Lines changed: 24 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
99

1010
template <cpy_kernel_t cpy_1>
11-
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
11+
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
1212
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
1313
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
1414
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
@@ -139,53 +139,25 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
139139
#endif
140140
}
141141

142-
static void ggml_cpy_f16_f32_cuda(
142+
template<typename src_t, typename dst_t>
143+
static void ggml_cpy_flt_cuda(
143144
const char * cx, char * cdst, const int ne,
144145
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
145146
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
146147

147148
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
148-
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
149+
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
149150
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
150151
}
151152

152-
static void ggml_cpy_bf16_f32_cuda(
153+
template<typename src_t>
154+
static void ggml_cpy_to_f16_cuda(
153155
const char * cx, char * cdst, const int ne,
154156
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
155157
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
156158

157159
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
158-
cpy_f32_f16<cpy_1_bf16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
159-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
160-
}
161-
162-
static void ggml_cpy_f32_f32_cuda(
163-
const char * cx, char * cdst, const int ne,
164-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
165-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
166-
167-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
168-
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
169-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
170-
}
171-
172-
static void ggml_cpy_f32_bf16_cuda(
173-
const char * cx, char * cdst, const int ne,
174-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
175-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
176-
177-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
178-
cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
179-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
180-
}
181-
182-
static void ggml_cpy_f32_f16_cuda(
183-
const char * cx, char * cdst, const int ne,
184-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
185-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
186-
187-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
188-
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
160+
cpy_flt<cpy_1_to_f16<src_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
189161
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
190162
}
191163

@@ -323,27 +295,7 @@ static void ggml_cpy_f16_f16_cuda(
323295
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
324296

325297
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
326-
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
327-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
328-
}
329-
330-
static void ggml_cpy_f16_bf16_cuda(
331-
const char * cx, char * cdst, const int ne,
332-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
333-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
334-
335-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
336-
cpy_f32_f16<cpy_1_f16_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
337-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
338-
}
339-
340-
static void ggml_cpy_bf16_f16_cuda(
341-
const char * cx, char * cdst, const int ne,
342-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
343-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
344-
345-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
346-
cpy_f32_f16<cpy_1_bf16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
298+
cpy_flt<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
347299
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
348300
}
349301

@@ -402,11 +354,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
402354
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
403355
}
404356
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
405-
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
357+
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
406358
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
407-
ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
359+
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
408360
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
409-
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
361+
ggml_cpy_to_f16_cuda<float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
410362
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
411363
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
412364
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -435,16 +387,16 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
435387
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
436388
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
437389
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
438-
ggml_cpy_f16_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
390+
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
439391
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
440-
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
392+
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
441393
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
442394
// Pure copy, doesn't need its own BF16 function
443395
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
444396
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
445-
ggml_cpy_bf16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
397+
ggml_cpy_to_f16_cuda<nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
446398
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
447-
ggml_cpy_bf16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
399+
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
448400
} else {
449401
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
450402
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -469,11 +421,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
469421
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
470422
return nullptr;
471423
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
472-
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
424+
return (void*) cpy_flt<cpy_1_flt<float, float>>;
473425
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
474-
return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
426+
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
475427
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
476-
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
428+
return (void*) cpy_flt<cpy_1_to_f16<float>>;
477429
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
478430
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
479431
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -497,17 +449,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
497449
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
498450
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
499451
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
500-
return (void*) cpy_f32_f16<cpy_1_f16_f16>;
452+
return (void*) cpy_flt<cpy_1_f16_f16>;
501453
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
502-
return (void*) cpy_f32_f16<cpy_1_f16_bf16>;
454+
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
503455
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
504-
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
456+
return (void*) cpy_flt<cpy_1_flt<half, float>>;
505457
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
506-
return (void*) cpy_f32_f16<cpy_1_bf16_f16>;
458+
return (void*) cpy_flt<cpy_1_to_f16<nv_bfloat16>>;
507459
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
508-
return (void*) cpy_f32_f16<cpy_1_f16_f16>;
460+
return (void*) cpy_flt<cpy_1_f16_f16>;
509461
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
510-
return (void*) cpy_f32_f16<cpy_1_bf16_f32>;
462+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
511463
} else {
512464
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
513465
ggml_type_name(src0->type), ggml_type_name(src1->type));

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@ __device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
1111

1212
template<>
1313
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
14-
convert_f32_f16(src_f, dst_h);
14+
convert_to_f16(src_f, dst_h);
1515
}
1616

1717
template<>
1818
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
19-
convert_f32_bf16(src_f, dst_b);
19+
convert_to_flt(src_f, dst_b);
2020
}
2121

2222
template<>
2323
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
24-
convert_f32_f32(src_f, dst_f);
24+
convert_to_flt(src_f, dst_f);
2525
}
2626

2727
// Generic quantized set_rows kernel template

0 commit comments

Comments (0)