Skip to content

Commit 5ba36f6

Browse files
HIP: Cleanup hipification header (#15285)
add expicit conversion operator to support older versions of rocm Switch over to hip_bf16 from legacy hip_bfloat16 Simplify RDNA3 define Reduce swap over of new hipblas api to rocm 6.5 as this version is used for rocm 7.0 previews --------- Co-authored-by: Johannes Gäßler <[email protected]>
1 parent b204a5a commit 5ba36f6

File tree

7 files changed

+32
-33
lines changed

7 files changed

+32
-33
lines changed

ggml/src/ggml-cuda/convert.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
3131
dequantize_kernel(vx, ib, iqs, v);
3232

3333
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
34-
y[iy0 + 0] = float(v.x);
35-
y[iy0 + y_offset] = float(v.y);
34+
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
35+
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
3636
}
3737

3838
template <bool need_check>
@@ -630,7 +630,7 @@ static __global__ void convert_unary(
630630

631631
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
632632
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
633-
y[iy] = float(x[ix]);
633+
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
634634
}
635635

636636
template <typename src_t, typename dst_t>

ggml/src/ggml-cuda/convert.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,16 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
2929
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
3030
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
3131
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
32+
33+
template<typename dst_t, typename src_t>
34+
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
35+
if constexpr (std::is_same_v<dst_t, src_t>) {
36+
return x;
37+
} else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
38+
return __float2bfloat16(float(x));
39+
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
40+
return __bfloat162float(x);
41+
} else {
42+
return float(x);
43+
}
44+
}

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
#pragma once
22

33
#include "ggml-common.h"
4-
5-
template<typename src_t, typename dst_t>
6-
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
7-
if constexpr (std::is_same_v<src_t, dst_t>) {
8-
*dst = *src;
9-
} else {
10-
*dst = float(*src);
11-
}
12-
}
4+
#include "convert.cuh"
135

146
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
157
if (x <= val[0]) return 0;
@@ -221,5 +213,5 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
221213

222214
template<typename src_t, typename dst_t>
223215
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
224-
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
216+
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
225217
}

ggml/src/ggml-cuda/getrows.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "getrows.cuh"
22
#include "dequantize.cuh"
3+
#include "convert.cuh"
34

45
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
56
static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
3435
dfloat2 v;
3536
dequantize_kernel(src0_row, ib, iqs, v);
3637

37-
dst_row[iybs + iqs + 0] = float(v.x);
38-
dst_row[iybs + iqs + y_offset] = float(v.y);
38+
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
39+
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
3940
}
4041

4142
template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
6263
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
6364
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
6465

65-
dst_row[i00] = float(src0_row[i00]);
66+
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
6667
}
6768

6869
template<typename grad_t, typename dst_t>

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "ggml.h"
22
#include "common.cuh"
3+
#include "convert.cuh"
34
#include "mmvf.cuh"
45

56
template <typename T, typename type_acc, int ncols_dst, int block_size>
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
9394
#pragma unroll
9495
for (int j = 0; j < ncols_dst; ++j) {
9596
const float2 tmpy = y2[j*stride_col_y2 + col2];
96-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
97+
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
98+
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
9899
}
99100
}
100101
} else {

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33

44
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
55

6-
template<typename src_t, typename dst_t>
7-
__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
8-
convert_flt(src_f, dst_f);
9-
}
10-
116
// Generic quantized set_rows kernel template
127
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
138
static __global__ void k_set_rows_quant(
@@ -117,9 +112,7 @@ static __global__ void k_set_rows(
117112
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
118113
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
119114

120-
const src_t* src_elem = src0_row + i00;
121-
dst_t* dst_elem = dst_row_ptr + i00;
122-
set_rows_1(src_elem, dst_elem);
115+
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
123116

124117
GGML_UNUSED(ne10);
125118
GGML_UNUSED(ne13);

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <hip/hip_runtime.h>
55
#include <hipblas/hipblas.h>
66
#include <hip/hip_fp16.h>
7-
#include <hip/hip_bfloat16.h>
7+
#include <hip/hip_bf16.h>
88

99
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
1010
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -135,7 +135,7 @@
135135
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
136136
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
137137

138-
#if HIP_VERSION >= 70000000
138+
#if HIP_VERSION >= 60500000
139139
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
140140
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
141141
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -147,7 +147,7 @@
147147
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
148148
#define cublasComputeType_t hipblasDatatype_t
149149
#define cudaDataType_t hipblasDatatype_t
150-
#endif // HIP_VERSION >= 7000000
150+
#endif // HIP_VERSION >= 6050000
151151

152152
#if !defined(__HIP_PLATFORM_AMD__)
153153
#error "The HIP backend supports only AMD targets"
@@ -179,8 +179,7 @@
179179
#define RDNA4
180180
#endif
181181

182-
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
183-
defined(__gfx1150__) || defined(__gfx1151__)
182+
#if defined(__GFX11__)
184183
#define RDNA3
185184
#endif
186185

@@ -197,8 +196,8 @@
197196
#define __has_builtin(x) 0
198197
#endif
199198

200-
typedef hip_bfloat16 nv_bfloat16;
201-
typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
199+
typedef __hip_bfloat16 nv_bfloat16;
200+
typedef __hip_bfloat162 nv_bfloat162;
202201

203202
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
204203
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

0 commit comments

Comments
 (0)