
Commit 44c8947

CUDA/HIP: add explicit conversion operator to support older versions of ROCm

1 parent 3ea913f

File tree (5 files changed: +56 −9 lines)

- ggml/src/ggml-cuda/convert.cu
- ggml/src/ggml-cuda/convert.cuh
- ggml/src/ggml-cuda/cpy-utils.cuh
- ggml/src/ggml-cuda/getrows.cu
- ggml/src/ggml-cuda/mmvf.cu
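
Why the change is needed: under HIP, nv_bfloat16 maps to ROCm's bfloat16 type, and older ROCm releases lack the float conversion operators that CUDA's nv_bfloat16 provides, so plain casts such as float(v.x), or implicit float-to-bf16 assignments, can fail to compile. The commit therefore routes these conversions through an explicit helper, ggml_cuda_convert_val, whose specializations call the conversion intrinsics directly. A minimal sketch of the failure mode and the fix (illustrative only, assuming the usual HIP type mapping):

    // These rely on the bf16 type's conversion operators and can fail to
    // compile on older ROCm:
    //     nv_bfloat16 b = 0.5f;       // implicit float -> bf16
    //     float       f = float(b);   // explicit cast bf16 -> float
    // Calling the intrinsics directly compiles on both CUDA and HIP:
    nv_bfloat16 b = __float2bfloat16(0.5f);
    float       f = __bfloat162float(b);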

ggml/src/ggml-cuda/convert.cu

Lines changed: 3 additions & 3 deletions

@@ -31,8 +31,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dequantize_kernel(vx, ib, iqs, v);

     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = float(v.x);
-    y[iy0 + y_offset] = float(v.y);
+    y[iy0 + 0]        = ggml_cuda_convert_val<float, dst_t>(v.x);
+    y[iy0 + y_offset] = ggml_cuda_convert_val<float, dst_t>(v.y);
 }

 template <bool need_check>
@@ -630,7 +630,7 @@ static __global__ void convert_unary(

     const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
     const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = float(x[ix]);
+    y[iy] = ggml_cuda_convert_val<src_t, dst_t>(x[ix]);
 }

 template <typename src_t, typename dst_t>

ggml/src/ggml-cuda/convert.cuh

Lines changed: 44 additions & 0 deletions

@@ -29,3 +29,47 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template<typename src_t, typename dest_t>
+__host__ __device__ inline dest_t ggml_cuda_convert_val(src_t x) {
+    if constexpr (std::is_same_v<src_t, dest_t>) {
+        return x;
+    } else {
+        return float(x);
+    }
+}
+
+template<>
+__host__ __device__ inline float ggml_cuda_convert_val<nv_bfloat16, float>(nv_bfloat16 x) {
+    return __bfloat162float(x);
+}
+
+template<>
+__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<nv_bfloat16, nv_bfloat16>(nv_bfloat16 x) {
+    return x;
+}
+
+template<>
+__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<float, nv_bfloat16>(float x) {
+    return __float2bfloat16(x);
+}
+
+template<>
+__host__ __device__ inline half ggml_cuda_convert_val<nv_bfloat16, half>(nv_bfloat16 x) {
+    return half(__bfloat162float(x));
+}
+
+template<>
+__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<half, nv_bfloat16>(half x) {
+    return __float2bfloat16(float(x));
+}
+
+template<>
+__host__ __device__ inline int ggml_cuda_convert_val<nv_bfloat16, int>(nv_bfloat16 x) {
+    return int(__bfloat162float(x));
+}
+
+template<>
+__host__ __device__ inline nv_bfloat16 ggml_cuda_convert_val<int, nv_bfloat16>(int x) {
+    return __float2bfloat16(float(x));
+}
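
Usage sketch (hypothetical kernel, not part of the commit): with the helper in place, an elementwise conversion kernel can be written once and instantiated for float, half, nv_bfloat16, or int, with the bf16 cases dispatching to the explicit specializations at compile time:

    template<typename src_t, typename dst_t>
    static __global__ void k_convert_example(const src_t * x, dst_t * y, const int64_t n) {
        const int64_t i = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
        if (i >= n) {
            return;
        }
        // picks the matching ggml_cuda_convert_val specialization, or the
        // generic float round-trip fallback, at compile time
        y[i] = ggml_cuda_convert_val<src_t, dst_t>(x[i]);
    }

Note that the generic fallback returns float(x) and relies on an implicit conversion to dest_t, which is fine for float/half pairs; all pairings involving nv_bfloat16 are covered by the explicit specializations above, so no bf16 conversion operator is ever required.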

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 2 additions & 1 deletion

@@ -1,13 +1,14 @@
 #pragma once

 #include "ggml-common.h"
+#include "convert.cuh"

 template<typename src_t, typename dst_t>
 static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
     if constexpr (std::is_same_v<src_t, dst_t>) {
         *dst = *src;
     } else {
-        *dst = float(*src);
+        *dst = ggml_cuda_convert_val<src_t, dst_t>(*src);
     }
 }

ggml/src/ggml-cuda/getrows.cu

Lines changed: 4 additions & 3 deletions

@@ -1,5 +1,6 @@
 #include "getrows.cuh"
 #include "dequantize.cuh"
+#include "convert.cuh"

 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);

-    dst_row[iybs + iqs + 0]        = float(v.x);
-    dst_row[iybs + iqs + y_offset] = float(v.y);
+    dst_row[iybs + iqs + 0]        = ggml_cuda_convert_val<float, dst_t>(v.x);
+    dst_row[iybs + iqs + y_offset] = ggml_cuda_convert_val<float, dst_t>(v.y);
 }

 template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

-    dst_row[i00] = float(src0_row[i00]);
+    dst_row[i00] = ggml_cuda_convert_val<src0_t, dst_t>(src0_row[i00]);
 }

 template<typename grad_t, typename dst_t>

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 3 additions & 2 deletions

@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "common.cuh"
+#include "convert.cuh"
 #include "mmvf.cuh"

 template <typename T, typename type_acc, int ncols_dst, int block_size>
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
         for (int j = 0; j < ncols_dst; ++j) {
             const float2 tmpy = y2[j*stride_col_y2 + col2];
-            sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-            sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf[j] += ggml_cuda_convert_val<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
         }
     } else {
