Commit 3782433

Merge branch 'ggml-org:master' into master
2 parents: 8d26b72 + 75cbdd3

11 files changed: +339 -46 lines changed

ggml/src/ggml-alloc.c

Lines changed: 11 additions & 4 deletions
@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
     }

     if (best_fit_block == -1) {
-        // no suitable block found, try the last block (this will grow a chunks size)
+        // no suitable block found, try the last block (this may grow a chunks size)
+        int64_t best_reuse = INT64_MIN;
         for (int c = 0; c < alloc->n_chunks; ++c) {
             struct tallocr_chunk * chunk = alloc->chunks[c];
             if (chunk->n_free_blocks > 0) {
                 struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
                 max_avail = MAX(max_avail, block->size);
-                if (block->size >= size) {
+                int64_t reuse_factor = chunk->max_size - block->offset - size;
+                // reuse_factor < 0 : amount of extra memory that needs to be allocated
+                // reuse_factor = 0 : allocated free space exactly matches tensor size
+                // reuse_factor > 0 : superfluous memory that will remain unused
+                bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
+                bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
+                if (block->size >= size && (better_reuse || better_fit)) {
                     best_fit_chunk = c;
                     best_fit_block = chunk->n_free_blocks - 1;
-                    break;
+                    best_reuse = reuse_factor;
                 }
             }
         }

@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, addr, tensor);
     size_t cur_max = addr.offset + size;
-    if (cur_max > alloc->max_size[addr.chunk]) {
+    if (cur_max > chunk->max_size) {
         // sort allocated_tensors by chunk/offset
         for (int i = 0; i < 1024; i++) {
             for (int j = i + 1; j < 1024; j++) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 9 additions & 8 deletions
@@ -2976,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 9];

         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;

@@ -2986,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
         }

@@ -3125,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (!disable_fusion) {

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+8];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 9];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
+                ggml_tensor * clamp = cgraph->nodes[i + 7];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                      /*delayed softmax*/ false);
-                i += 8;
+                                      /*delayed softmax*/ false, clamp);
+                i += 9;
                 continue;
             }

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+4];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 4];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
                                       /*delayed softmax*/ false);
                 i += 4;
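
The index changes follow from the longer fused op sequence: with GGML_OP_CLAMP added between SUM_ROWS and DIV, the final RESHAPE that produces the routing weights moves from node_idx + 8 to node_idx + 9, and the CLAMP node itself sits at node_idx + 7. A sketch of the assumed layout (the constants below are illustrative, not identifiers from the commit):

    // Fused top-k MoE subgraph (with norm), relative to the SOFT_MAX node at i:
    //   i+0 SOFT_MAX, i+1 RESHAPE, i+2 ARGSORT,  i+3 VIEW,  i+4 GET_ROWS,
    //   i+5 RESHAPE,  i+6 SUM_ROWS, i+7 CLAMP,   i+8 DIV,   i+9 RESHAPE
    constexpr int kSelectedExpertsOffset = 3;  // VIEW holding the selected expert ids
    constexpr int kClampOffset           = 7;  // the newly fused CLAMP node
    constexpr int kWeightsOffset         = 9;  // final RESHAPE producing the routing weights
    // hence weights = nodes[i + 9], clamp = nodes[i + 7], and the fused span advances i += 9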

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 47 additions & 19 deletions
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "topk-moe.cuh"

+#include <cmath>
 #include <initializer_list>

 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.

@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                   float * weights,
                                                                   int32_t * ids,
                                                                   const int n_rows,
-                                                                  const int n_expert_used) {
+                                                                  const int n_expert_used,
+                                                                  const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;

@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;

         for (int i = 0; i < experts_per_thread; i++) {

@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }

 template <bool with_norm, bool delayed_softmax = false>

@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
-                                 const int n_expert_used) {
+                                 const int n_expert_used,
+                                 const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);

@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");

@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);

@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,

     const int n_expert_used = weights->ne[1];

+    float clamp_val = -INFINITY;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
+        GGML_ASSERT(clamp == nullptr);
        if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };

     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
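
In the normalized path the kernel now clamps the top-k weight sum from below before dividing, mirroring the CLAMP(min, +INF) node that the fusion absorbs. A host-side reference of the same math, as a sketch only (the vector is assumed to already hold one row's top-k softmax weights):

    #include <algorithm>
    #include <vector>

    // Reference for the with_norm path: divide each top-k weight by the clamped sum.
    // Passing clamp_min = -INFINITY reproduces the previous, unclamped behaviour.
    static std::vector<float> normalize_topk(std::vector<float> topk, float clamp_min) {
        float wt_sum = 0.0f;
        for (float w : topk) {
            wt_sum += w;
        }
        wt_sum = std::max(wt_sum, clamp_min);  // mirrors wt_sum = max(wt_sum, clamp_val) in the kernel
        const float inv_sum = 1.0f / wt_sum;
        for (float & w : topk) {
            w *= inv_sum;
        }
        return topk;
    }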

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 3 additions & 2 deletions
@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax = false);
+                           const bool delayed_softmax = false,
+                           ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
 #include "pad.hpp"
 #include "quantize.hpp"
 #include "quants.hpp"
+#include "roll.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"
 #include "softmax.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 18 additions & 0 deletions
@@ -48,6 +48,7 @@
 #include "ggml-sycl/set.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/repeat_back.hpp"
 #include "ggml-sycl/quantize.hpp"
 #include "ggml.h"

@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

+static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_repeat_back(ctx, dst);
+}

 static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);

@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_sycl_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_sycl_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_sycl_get_rows(ctx, dst);
             break;

@@ -3913,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_ROLL:
+            ggml_sycl_roll(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_sycl_arange(ctx, dst);
             break;

@@ -4516,6 +4527,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             }
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type == GGML_TYPE_F32;
+            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:

@@ -4586,6 +4602,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
             return true;
+        case GGML_OP_ROLL:
+            return op->type == GGML_TYPE_F32;
         case GGML_OP_ARANGE:
             return op->type == GGML_TYPE_F32;
         default:
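
For context, GGML_OP_REPEAT_BACK is the reduction used when differentiating through ggml_repeat: every element of the smaller tensor accumulates all of its repeated copies. A minimal graph-construction sketch using the public ggml C API; the exact shapes and call sequence are an assumption for illustration, not code from this commit:

    #include "ggml.h"

    int main() {
        // small scratch context; sizes are arbitrary for this sketch
        ggml_init_params params = { /*.mem_size   =*/ 16 * 1024 * 1024,
                                    /*.mem_buffer =*/ nullptr,
                                    /*.no_alloc   =*/ false };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * small = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);  // target shape
        ggml_tensor * big   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);  // 2x2 repeats of `small`

        // sums the repeated copies in `big` back down to the shape of `small`;
        // with this change the SYCL backend reports support for this node (F32 only)
        ggml_tensor * reduced = ggml_repeat_back(ctx, big, small);
        (void) reduced;

        ggml_free(ctx);
        return 0;
    }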

ggml/src/ggml-sycl/repeat_back.cpp

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+#include "repeat_back.hpp"
+
+#include "common.hpp"
+
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const float * src0_dd = (const float *) dst->src[0]->data;
+    float * dst_dd = (float *) dst->data;
+
+    const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
+    const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
+                  ne03 = dst->src[0]->ne[3];
+
+    const int nr0 = (int) (ne00 / ne0);
+    const int nr1 = (int) (ne01 / ne1);
+    const int nr2 = (int) (ne02 / ne2);
+    const int nr3 = (int) (ne03 / ne3);
+
+    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const int BLOCK_SIZE = 256;
+    const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    queue_ptr stream = ctx.stream();
+
+    stream->parallel_for(
+        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            const size_t i = item_ct1.get_global_linear_id();
+            if (i >= total) {
+                return;
+            }
+
+            const int i0 = i % ne0;
+            const int i1 = (i / ne0) % ne1;
+            const int i2 = (i / (ne0 * ne1)) % ne2;
+            const int i3 = i / (ne0 * ne1 * ne2);
+
+            float acc = 0.0f;
+
+            for (int j3 = 0; j3 < nr3; ++j3) {
+                for (int j2 = 0; j2 < nr2; ++j2) {
+                    for (int j1 = 0; j1 < nr1; ++j1) {
+                        for (int j0 = 0; j0 < nr0; ++j0) {
+                            acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
+                                           (i3 + j3 * ne3) * ne00 * ne01 * ne02];
+                        }
+                    }
+                }
+            }
+
+            dst_dd[i] = acc;
+        });
+}
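
The kernel assumes contiguous F32 tensors and folds every repeated copy of an element into a single accumulator. A serial host-side reference with the same indexing, useful as a mental model or for spot-checking the SYCL path (the function and parameter names here are illustrative):

    #include <cstdint>
    #include <vector>

    // dst has shape (ne0, ne1, ne2, ne3); src has shape (ne0*nr0, ne1*nr1, ne2*nr2, ne3*nr3).
    // Each dst element accumulates all nr0*nr1*nr2*nr3 repeated copies of itself in src.
    static void repeat_back_ref(const std::vector<float> & src, std::vector<float> & dst,
                                int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                                int nr0, int nr1, int nr2, int nr3) {
        const int64_t ne00 = ne0 * nr0, ne01 = ne1 * nr1, ne02 = ne2 * nr2;
        for (int64_t i3 = 0; i3 < ne3; ++i3)
        for (int64_t i2 = 0; i2 < ne2; ++i2)
        for (int64_t i1 = 0; i1 < ne1; ++i1)
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            float acc = 0.0f;
            for (int j3 = 0; j3 < nr3; ++j3)
            for (int j2 = 0; j2 < nr2; ++j2)
            for (int j1 = 0; j1 < nr1; ++j1)
            for (int j0 = 0; j0 < nr0; ++j0) {
                acc += src[(i0 + j0 * ne0) +
                           (i1 + j1 * ne1) * ne00 +
                           (i2 + j2 * ne2) * ne00 * ne01 +
                           (i3 + j3 * ne3) * ne00 * ne01 * ne02];
            }
            dst[i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2] = acc;
        }
    }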

ggml/src/ggml-sycl/repeat_back.hpp

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_REPEAT_BACK_HPP
+#define GGML_SYCL_REPEAT_BACK_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_REPEAT_BACK_HPP
