Commit b64ba1c
Address review comments
1 parent 5adf50e

File tree

4 files changed: +64 / -64 lines changed

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 26 additions & 53 deletions
@@ -2,35 +2,14 @@
 #include <cstdint>
 #include <utility>
 
-static __device__ __forceinline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __device__ __forceinline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __device__ __forceinline__ float op_sub(const float a, const float b) {
-    return a - b;
-}
-
-static __device__ __forceinline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __device__ __forceinline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2, const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -55,26 +34,20 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         const int i10 = i0 % ne10;
 
         float result = src0_row ? (float) src0_row[i0] : 0.0f;
-
-        auto add_one = [&](const src1_t * p) {
-            const src1_t * row = p + i_src1;
-            result = bin_op(result, (float) row[i10]);
-            return 0;
-        };
-        (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
 
         dst_row[i0] = (dst_t) result;
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2,const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs ... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     const int i3 = i/(ne2*ne1*ne0);
@@ -100,13 +73,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
     const int i10 = i0 % ne10;
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
-
-    auto add_one = [&](const src1_t * p) {
-        const src1_t * row = p + i_src1;
-        result = bin_op(result, (float) row[i10]);
-        return 0;
-    };
-    (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+    result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
 
     dst_row[i0] = (dst_t) result;
 }
@@ -291,7 +258,8 @@ static __global__ void k_repeat_back(
     dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
-template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_bcast_cuda {
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
     void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
         const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
@@ -355,26 +323,27 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
-template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();
 
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
 
    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
@@ -385,30 +354,32 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
    }
 }
 
-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+
+template<float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
 
    switch (n_fuse) {
        case 2:
-            ggml_cuda_op_fused_add_impl<2>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 2>(ctx, dst);
            break;
        case 3:
-            ggml_cuda_op_fused_add_impl<3>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 3>(ctx, dst);
            break;
        case 4:
-            ggml_cuda_op_fused_add_impl<4>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 4>(ctx, dst);
            break;
        case 5:
-            ggml_cuda_op_fused_add_impl<5>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 5>(ctx, dst);
            break;
        case 6:
-            ggml_cuda_op_fused_add_impl<6>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 6>(ctx, dst);
            break;
        case 7:
-            ggml_cuda_op_fused_add_impl<7>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 7>(ctx, dst);
            break;
        case 8:
-            ggml_cuda_op_fused_add_impl<8>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
@@ -445,3 +416,5 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
        } break;
    }
 }
+
+template void ggml_cuda_op_fused_binbcast<op_add>(ggml_backend_cuda_context &, ggml_tensor *, int);
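
A note on the kernel change above: the per-pointer lambda plus initializer_list expansion is replaced by a C++17 unary left fold over the comma operator, which applies bin_op once per src1 pointer in pack order. Below is a minimal host-side sketch of the same idiom; fold_accumulate, combine, and the test values are illustrative only and not part of the patch.

#include <cstdio>

static float combine(float a, float b) { return a + b; }

template <typename... Srcs>
static float fold_accumulate(float acc, Srcs... srcs) {
    // Unary left fold over the comma operator: each pack element updates acc
    // in order, replacing the old
    //   (void) std::initializer_list<int>{ (add_one(srcs), 0)... };
    // expansion with a single expression.
    acc = (..., (acc = combine(acc, srcs)));
    return acc;
}

int main() {
    const float a = 1.0f, b = 2.0f, c = 3.0f;
    // Accumulates three values into acc, the way the kernel folds over its
    // src1 pointer pack for each output element.
    std::printf("%f\n", fold_accumulate(0.5f, a, b, c)); // prints 6.500000
    return 0;
}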

ggml/src/ggml-cuda/binbcast.cuh

Lines changed: 25 additions & 2 deletions
@@ -1,11 +1,34 @@
 #include "common.cuh"
 
+
+__device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+__device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+__device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+__device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+__device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
-
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+template<float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
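
The header now exposes the op_* device functions so callers can name them as template arguments, while ggml_cuda_op_fused_binbcast is only declared here and explicitly instantiated for op_add at the bottom of binbcast.cu. A minimal host-only C++ sketch of that declaration/explicit-instantiation split follows; apply_fused and my_add are illustrative names, not part of the patch.

#include <cstdio>

// "Header" part: declaration only, plus an op usable as a template argument.
inline float my_add(float a, float b) { return a + b; }

template <float (*op)(float, float)>
void apply_fused(float * dst, const float * src, int n);

// "Source file" part: the definition, followed by an explicit instantiation
// for each op that is actually dispatched to (here just my_add), so other
// translation units only ever need the declaration above.
template <float (*op)(float, float)>
void apply_fused(float * dst, const float * src, int n) {
    for (int i = 0; i < n; ++i) {
        dst[i] = op(dst[i], src[i]);
    }
}

template void apply_fused<my_add>(float *, const float *, int);

int main() {
    float dst[2] = {1.0f, 2.0f};
    const float src[2] = {10.0f, 20.0f};
    apply_fused<my_add>(dst, src, 2);
    std::printf("%f %f\n", dst[0], dst[1]); // prints 11.000000 22.000000
    return 0;
}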

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 3 deletions
@@ -2817,7 +2817,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    if (ops.size() >= 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
         const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
         const ggml_tensor *mul = cgraph->nodes[node_idx+1];
         const ggml_tensor *add = nullptr;
@@ -2905,7 +2905,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
-                ggml_op ops[8] = {GGML_OP_ADD};
+                ggml_op ops[8];
+                std::fill(ops, ops + 8, GGML_OP_ADD);
 
                 for (; n_fuse <= 6; ++n_fuse){
                     if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
@@ -2926,8 +2927,9 @@
                         node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
                     }
                     cgraph->nodes[i + n_fuse - 1]->data = node->data;
-                    ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                    ggml_cuda_op_fused_binbcast<op_add>(*cuda_ctx, node, n_fuse);
                     i += n_fuse - 1;
+
                     continue;
                 }
             }

ggml/src/ggml-cuda/norm.cu

Lines changed: 8 additions & 6 deletions
@@ -178,14 +178,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        if constexpr (do_multiply) {
+        if constexpr (do_multiply && do_add) {
+            const int mul_col = col % mul_ncols;
+            const int add_col = col % add_ncols;
+            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+        } else if constexpr (do_multiply) {
             const int mul_col = col % mul_ncols;
             dst[col] = scale * x[col] * mul[mul_col];
-
-            if constexpr (do_add) {
-                const int add_col = col % add_ncols;
-                dst[col] += add[add_col];
-            }
+        } else if constexpr (do_add) {
+            const int add_col = col % add_ncols;
+            dst[col] += add[add_col];
         } else {
             dst[col] = scale * x[col];
         }
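
The restructure above gives each (do_multiply, do_add) combination its own `if constexpr` branch, so the fused RMS_NORM + MUL + ADD path writes dst in a single statement and untaken branches are compiled out. A host-side sketch of that dispatch shape; the function, its arguments, and the do_add-only arithmetic here are simplified illustrations, not the kernel's exact code.

#include <cstdio>

template <bool do_multiply, bool do_add>
static float rms_norm_elem(float scale, float x, float mul, float add) {
    if constexpr (do_multiply && do_add) {
        return scale * x * mul + add;   // fused multiply + add in one store
    } else if constexpr (do_multiply) {
        return scale * x * mul;         // multiply only
    } else if constexpr (do_add) {
        return scale * x + add;         // add-only arm, simplified for illustration
    } else {
        return scale * x;               // plain RMS norm scaling
    }
}

int main() {
    std::printf("%f\n", rms_norm_elem<true, true>(0.5f, 2.0f, 3.0f, 1.0f));   // 4.000000
    std::printf("%f\n", rms_norm_elem<true, false>(0.5f, 2.0f, 3.0f, 1.0f));  // 3.000000
    std::printf("%f\n", rms_norm_elem<false, false>(0.5f, 2.0f, 3.0f, 1.0f)); // 1.000000
    return 0;
}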

0 commit comments