
Commit 7fa80c1

added fp16 kernel support
1 parent cc3d366

File tree

ggml/src/ggml-cuda/conv2d-mm.cu
ggml/src/ggml-cuda/ggml-cuda.cu

2 files changed: +23 −13 lines

ggml/src/ggml-cuda/conv2d-mm.cu

Lines changed: 22 additions & 12 deletions
@@ -1,4 +1,5 @@
 #include "conv2d-mm.cuh"
+#include "convert.cuh"
 
 #include <cuda_runtime.h>
 
@@ -13,6 +14,8 @@
 
 #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
 
+typedef uint32_t uint;
+
 uint32_t ceil_div(uint32_t M, uint32_t N);
 int get_sm_count();
 
@@ -69,11 +72,11 @@ __inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
 }
 
 // --> conv_2d kernel modified to function as a matmul
-template <uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
+template <typename T, uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
 __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
                                                  uint NPQ,
                                                  uint CRS,
-                                                 const float * knl_data,
+                                                 const T * knl_data,
                                                  const float * src_data,
                                                  float * dst_data) {
     // Each block computes a tile of the result of size BS_K*BS_NPQ
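Making the element type a template parameter is what carries the rest of the change: the compiler now emits one instantiation of the kernel per weight format, and only the A-side touch points (shared tile, boundary fill, register load) need to switch from float to T. A minimal standalone sketch of the same pattern follows; the kernel name and shapes are illustrative, not part of the commit.

    #include <cuda_fp16.h>

    // One templated kernel, two compiled instantiations: T is the storage
    // type of the weights, while the arithmetic stays in fp32.
    template <typename T>
    __global__ void scale_sketch(int n, const T * w, const float * x, float * y) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = (float) w[i] * x[i];  // widen half (or float) before the multiply
        }
    }

    // Host side, the template argument selects the weight format:
    //   scale_sketch<half> <<<grid, block>>>(n, (const half *)  w, x, y);
    //   scale_sketch<float><<<grid, block>>>(n, (const float *) w, x, y);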
@@ -98,7 +101,8 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
     const uint T_y = threadIdx.x / NT_x;
     const uint T_x = threadIdx.x % NT_x;
 
-    __shared__ float Ash[BS_K * BS_CRS];
+    // __shared__ float Ash[BS_K * BS_CRS];
+    __shared__ T Ash[BS_K * BS_CRS];
     __shared__ float Bsh[BS_CRS * BS_NPQ];
 
     const uint Ar = threadIdx.x / BS_CRS;
@@ -148,9 +152,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
         // General addressing (does not assume contiguity)
         //const uint32_t knl_idx = KW_idx_a + KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03;
         // Contiguous addressing
-        float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
+        T val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
         if (CRS_idx_a >= CRS || K_idx_a >= K) {
-            val = 0.0;
+            val = (T)0.0;
         }
 
 #ifdef A_TRANS
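The load keeps the trick the fp32 version already used: the index is clamped with min so every lane issues an in-bounds global read with no branch in front of it, and lanes whose logical index fell outside the tensor are zeroed afterwards. The only genuinely new bit is the explicit (T) cast on the zero, which keeps the assignment well-formed when T is half. The pattern in isolation, with an illustrative helper name:

    // Clamped load plus fix-up: the read is always valid, and padding
    // lanes are zeroed so they contribute nothing to the dot product.
    template <typename T>
    __device__ __forceinline__ T load_or_zero(const T * buf, uint32_t idx, uint32_t n) {
        T val = buf[min(idx, n - 1)];  // clamp instead of branching around the load
        if (idx >= n) {
            val = (T) 0.0f;
        }
        return val;
    }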
@@ -235,9 +239,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
 # else
             uint32_t col_offset = (T_y * TS_K + T_ly);
 # endif
-            regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset];
+            regA[T_ly] = ggml_cuda_cast<float>(Ash[CRS_lidx * BS_K + col_offset]);
 #else
-            regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx];
+            regA[T_ly] = ggml_cuda_cast<float>(Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]);
 #endif
         }
         for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) {
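Only the staging of the A tile changes precision: Ash now holds the weights in their storage type, and ggml_cuda_cast<float> (from the convert.cuh header added at the top of the file) widens each element as it is pulled into the regA registers, so the inner multiply-accumulate still runs in fp32. For the half/float pair the cast behaves like a plain static_cast; an illustrative stand-in:

    // Stand-in with the same effect as the call above for half and float
    // (the real helper in convert.cuh also covers other storage types):
    template <typename dst_t, typename src_t>
    __device__ __forceinline__ dst_t cast_sketch(src_t v) {
        return static_cast<dst_t>(v);  // half -> float widening before the FMA loop
    }

A side effect worth noting: with T = half the A tile occupies half the shared memory it did before, which can help occupancy.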
@@ -343,16 +347,22 @@ void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx,
     cudaMemcpyToSymbol(dp, &p, sizeof(Params));
 
     // Kernel arguments
-    float * src0_data = (float *) src0->data;
     float * src1_data = (float *) src1->data;
     float * dst_data = (float *) dst->data;
 
     dim3 gridDim(NB_K, NB_NPQ);
     dim3 blockDim(WG_SIZE);
     cudaStream_t stream = ctx.stream();
-
-    mm<BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_data = (half *) src0->data;
+        mm<half, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
+            <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    } else {
+        float * src0_data = (float *) src0->data;
+        mm<float, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
         <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    }
+
 }
 
 void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
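On the host side the weight pointer can no longer be typed up front, so the launch is duplicated per storage type and src0->type picks the template argument. Reduced to a standalone sketch (the enum and helper below are hypothetical stand-ins for the real ggml types, not code from the commit):

    #include <cuda_fp16.h>
    #include <cstdio>

    enum class weight_type { f16, f32 };

    // Stand-in for the real launch site: this is where
    // mm<T, ...><<<gridDim, blockDim, 0, stream>>>(...) would be invoked.
    template <typename T>
    static void launch_mm(const void * weights) {
        const T * typed = (const T *) weights;
        (void) typed;
        printf("launching with %zu-byte weight elements\n", sizeof(T));
    }

    static void dispatch(weight_type t, const void * weights) {
        if (t == weight_type::f16) {
            launch_mm<half>(weights);   // new fp16 path
        } else {
            launch_mm<float>(weights);  // original fp32 path
        }
    }

If more formats (say bf16) are added later, a switch over the tensor type would likely read better than a growing if/else chain.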
@@ -372,13 +382,13 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    // GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    // GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));
 
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -2463,7 +2463,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst)
             break;
         case GGML_OP_CONV_2D:
             if (!getenv("GGML_CUDA_USE_LEGACY_CONV") &&
-                (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
+                (dst->src[1]->type == GGML_TYPE_F32 &&
                  dst->type == GGML_TYPE_F32)) {
                 ggml_cuda_op_conv2d_mm(ctx, dst);
             } else {
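Note that the relaxed guard only pins src1 and dst to F32; combined with the commented-out assert in ggml_cuda_op_conv2d_mm, src0 is no longer checked anywhere, and any non-F16 weight tensor falls through to the float branch of the launcher. A stricter predicate matching what the kernel is actually instantiated for might look like this (hypothetical, not part of the commit):

    // Accept only the two weight types the templated kernel supports.
    static bool conv2d_mm_supported(const ggml_tensor * dst) {
        const ggml_type t0 = dst->src[0]->type;
        return (t0 == GGML_TYPE_F16 || t0 == GGML_TYPE_F32) &&
               dst->src[1]->type == GGML_TYPE_F32 &&
               dst->type == GGML_TYPE_F32;
    }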
