Skip to content

Commit 1e3a92b

Browse files
committed
CUDA/HIP: replace further casts with ggml_cuda_convert_val
1 parent 4738fa6 commit 1e3a92b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

ggml/src/ggml-cuda/getrows.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "getrows.cuh"
22
#include "dequantize.cuh"
3+
#include "convert.cuh"
34

45
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
56
static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
3435
dfloat2 v;
3536
dequantize_kernel(src0_row, ib, iqs, v);
3637

37-
dst_row[iybs + iqs + 0] = float(v.x);
38-
dst_row[iybs + iqs + y_offset] = float(v.y);
38+
dst_row[iybs + iqs + 0] = ggml_cuda_convert_val<float, dst_t>(v.x);
39+
dst_row[iybs + iqs + y_offset] = ggml_cuda_convert_val<float, dst_t>(v.y);
3940
}
4041

4142
template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
6263
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
6364
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
6465

65-
dst_row[i00] = float(src0_row[i00]);
66+
dst_row[i00] = ggml_cuda_convert_val<src0_t, dst_t>(src0_row[i00]);
6667
}
6768

6869
template<typename grad_t, typename dst_t>

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "ggml.h"
22
#include "common.cuh"
3+
#include "convert.cuh"
34
#include "mmvf.cuh"
45

56
template <typename T, typename type_acc, int ncols_dst, int block_size>
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
9394
#pragma unroll
9495
for (int j = 0; j < ncols_dst; ++j) {
9596
const float2 tmpy = y2[j*stride_col_y2 + col2];
96-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
97+
sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
98+
sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
9899
}
99100
}
100101
} else {

0 commit comments

Comments
 (0)