File tree Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Original file line number Diff line number Diff line change 11#include " getrows.cuh"
22#include " dequantize.cuh"
3+ #include " convert.cuh"
34
45template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
56static __global__ void k_get_rows (
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
3435 dfloat2 v;
3536 dequantize_kernel (src0_row, ib, iqs, v);
3637
37- dst_row[iybs + iqs + 0 ] = float (v.x );
38- dst_row[iybs + iqs + y_offset] = float (v.y );
38+ dst_row[iybs + iqs + 0 ] = ggml_cuda_convert_val< float , dst_t > (v.x );
39+ dst_row[iybs + iqs + y_offset] = ggml_cuda_convert_val< float , dst_t > (v.y );
3940}
4041
4142template <typename src0_t , typename dst_t >
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
6263 dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
6364 const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
6465
65- dst_row[i00] = float (src0_row[i00]);
66+ dst_row[i00] = ggml_cuda_convert_val< src0_t , dst_t > (src0_row[i00]);
6667}
6768
6869template <typename grad_t , typename dst_t >
Original file line number Diff line number Diff line change 11#include " ggml.h"
22#include " common.cuh"
3+ #include " convert.cuh"
34#include " mmvf.cuh"
45
56template <typename T, typename type_acc, int ncols_dst, int block_size>
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
9394#pragma unroll
9495 for (int j = 0 ; j < ncols_dst; ++j) {
9596 const float2 tmpy = y2[j*stride_col_y2 + col2];
96- sumf[j] += float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]) * tmpy.x ;
97- sumf[j] += float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]) * tmpy.y ;
97+ sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float > (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]) * tmpy.x ;
98+ sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float > (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]) * tmpy.y ;
9899 }
99100 }
100101 } else {
You can’t perform that action at this time.
0 commit comments