Commit 6dcde5a

swap arg order, consolidate
1 parent 173df5e commit 6dcde5a

6 files changed: +17, -55 lines

In short: the template parameters of ggml_cuda_cast are swapped so that call sites name only the destination type and the source type is deduced from the argument; the explicit nv_bfloat16 specializations are folded into the primary template as constexpr branches, which lets the thin wrappers convert_flt (cpy-utils.cuh) and set_rows_1 (set-rows.cu) be removed.

ggml/src/ggml-cuda/convert.cu (3 additions & 3 deletions)

@@ -31,8 +31,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dequantize_kernel(vx, ib, iqs, v);

     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = ggml_cuda_cast<float, dst_t>(v.x);
-    y[iy0 + y_offset] = ggml_cuda_cast<float, dst_t>(v.y);
+    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }

 template <bool need_check>

@@ -630,7 +630,7 @@ static __global__ void convert_unary(

     const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
     const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = ggml_cuda_cast<src_t, dst_t>(x[ix]);
+    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
 }

 template <typename src_t, typename dst_t>

ggml/src/ggml-cuda/convert.cuh (7 additions & 33 deletions)

@@ -30,41 +30,15 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

-template<typename src_t, typename dest_t>
-__host__ __device__ inline dest_t ggml_cuda_cast(src_t x) {
-    if constexpr (std::is_same_v<src_t, dest_t>) {
+template<typename dst_t, typename src_t>
+__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
         return x;
+    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
+        return __float2bfloat16(float(x));
+    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
+        return __bfloat162float(x);
     } else {
         return float(x);
     }
 }
-
-template<>
-__host__ __device__ inline float ggml_cuda_cast<nv_bfloat16, float>(nv_bfloat16 x) {
-    return __bfloat162float(x);
-}
-
-template<>
-__host__ __device__ inline nv_bfloat16 ggml_cuda_cast<float, nv_bfloat16>(float x) {
-    return __float2bfloat16(x);
-}
-
-template<>
-__host__ __device__ inline half ggml_cuda_cast<nv_bfloat16, half>(nv_bfloat16 x) {
-    return half(__bfloat162float(x));
-}
-
-template<>
-__host__ __device__ inline nv_bfloat16 ggml_cuda_cast<half, nv_bfloat16>(half x) {
-    return __float2bfloat16(float(x));
-}
-
-template<>
-__host__ __device__ inline int ggml_cuda_cast<nv_bfloat16, int>(nv_bfloat16 x) {
-    return int(__bfloat162float(x));
-}
-
-template<>
-__host__ __device__ inline nv_bfloat16 ggml_cuda_cast<int, nv_bfloat16>(int x) {
-    return __float2bfloat16(float(x));
-}
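
For orientation, here is a minimal standalone sketch of how the consolidated cast behaves after this change. The template body is copied verbatim from the hunk above; everything around it (the cast_sketch.cu file name, the nvcc invocation, main, and the sample values) is illustrative only, and it assumes a CUDA 11+ toolkit where the cuda_bf16.h / cuda_fp16.h conversion helpers are host-callable (i.e. __CUDA_NO_BFLOAT16_CONVERSIONS__ and __CUDA_NO_HALF_CONVERSIONS__ are not defined):

// cast_sketch.cu -- illustrative sketch, not part of the commit
// build (assumed): nvcc -std=c++17 cast_sketch.cu -o cast_sketch
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>
#include <cstdio>

// Copied from ggml/src/ggml-cuda/convert.cuh as of this commit: the destination
// type comes first, so call sites write ggml_cuda_cast<dst_t>(x) and src_t is deduced.
template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {
        return x;
    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
        return __float2bfloat16(float(x));
    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
        return __bfloat162float(x);
    } else {
        return float(x);
    }
}

int main() {
    // The constexpr branches cover the conversions that the deleted explicit
    // specializations used to handle: bf16 <-> float, bf16 -> half, bf16 -> int.
    const nv_bfloat16 b = ggml_cuda_cast<nv_bfloat16>(1.5f); // float -> bf16 via __float2bfloat16
    const float       f = ggml_cuda_cast<float>(b);          // bf16  -> float via __bfloat162float
    const half        h = ggml_cuda_cast<half>(b);           // bf16  -> float, then narrowed to half
    const int         i = ggml_cuda_cast<int>(b);            // bf16  -> float, then truncated to int
    printf("f=%g h=%g i=%d\n", f, float(h), i);
    return 0;
}

Putting dst_t first matches how the call sites read (only the destination type is spelled out), and resolving the bf16 paths with constexpr branches is what makes the six explicit specializations removed above unnecessary.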

ggml/src/ggml-cuda/cpy-utils.cuh (1 addition & 6 deletions)

@@ -3,11 +3,6 @@
 #include "ggml-common.h"
 #include "convert.cuh"

-template<typename src_t, typename dst_t>
-static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
-    *dst = ggml_cuda_cast<src_t, dst_t>(*src);
-}
-
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;

@@ -218,5 +213,5 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {

 template<typename src_t, typename dst_t>
 static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
-    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
+    *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
 }

ggml/src/ggml-cuda/getrows.cu (3 additions & 3 deletions)

@@ -35,8 +35,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);

-    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<float, dst_t>(v.x);
-    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<float, dst_t>(v.y);
+    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }

 template<typename src0_t, typename dst_t>

@@ -63,7 +63,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

-    dst_row[i00] = ggml_cuda_cast<src0_t, dst_t>(src0_row[i00]);
+    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 }

 template<typename grad_t, typename dst_t>

ggml/src/ggml-cuda/mmvf.cu (2 additions & 2 deletions)

@@ -94,8 +94,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += ggml_cuda_cast<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += ggml_cuda_cast<nv_bfloat16, float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
             }
         }
     } else {

ggml/src/ggml-cuda/set-rows.cu (1 addition & 8 deletions)

@@ -3,11 +3,6 @@

 typedef void (*set_rows_kernel_t)(const char * src, char * dst);

-template<typename src_t, typename dst_t>
-__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    convert_flt(src_f, dst_f);
-}
-
 // Generic quantized set_rows kernel template
 template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static __global__ void k_set_rows_quant(

@@ -117,9 +112,7 @@ static __global__ void k_set_rows(
     const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
     dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;

-    const src_t* src_elem = src0_row + i00;
-    dst_t* dst_elem = dst_row_ptr + i00;
-    set_rows_1(src_elem, dst_elem);
+    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

     GGML_UNUSED(ne10);
     GGML_UNUSED(ne13);
