
Commit a18705a

Merge branch 'master' into gg/server-fix-n-past
2 parents: 1e1b4af + b9ce940

46 files changed: +2299 -1128 lines

CODEOWNERS

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@
 /ggml/src/ggml-impl.h   @ggerganov @slaren
 /ggml/src/ggml-metal/   @ggerganov
 /ggml/src/ggml-opencl/  @lhez @max-krasnyansky
-/ggml/src/ggml-hexagon/ @max-krasnyansky
+/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
 /ggml/src/ggml-opt.cpp  @JohannesGaessler
 /ggml/src/ggml-quants.* @ggerganov
 /ggml/src/ggml-rpc/     @rgerganov

docs/ops.md

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ Legend:
 | REPEAT |||| 🟡 || 🟡 || 🟡 ||
 | REPEAT_BACK ||||||||||
 | RMS_NORM ||||| 🟡 |||||
-| RMS_NORM_BACK ||||||| |||
+| RMS_NORM_BACK ||||||| |||
 | RMS_NORM_MUL_ADD ||||||||||
 | ROLL ||||||||||
 | ROPE || 🟡 ||||||||
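
For reference, RMS_NORM_BACK is the backward pass of RMS normalization (its support status is what flips in the SYCL table below). A standard derivation, stated here for context rather than taken from this commit, with g = ∂L/∂y the incoming gradient and n the row length:

y_i = \frac{x_i}{s}, \qquad s = \sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}

\frac{\partial L}{\partial x_i} = \frac{1}{s}\left( g_i - \frac{x_i}{n\,s^2} \sum_{j=1}^{n} g_j x_j \right)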

docs/ops/SYCL.csv

Lines changed: 4 additions & 4 deletions

@@ -5637,25 +5637,25 @@
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
 "SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"

ggml/src/ggml-cuda/common.cuh

Lines changed: 5 additions & 2 deletions

@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 // and a shift:
 //
 // n/d = (mulhi(n, mp) + n) >> L;
-static const uint3 init_fastdiv_values(uint32_t d) {
-    GGML_ASSERT(d != 0);
+static const uint3 init_fastdiv_values(uint64_t d_64) {
+    GGML_ASSERT(d_64 != 0);
+    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
+
+    uint32_t d = (uint32_t) d_64;
 
     // compute L = ceil(log2(d));
     uint32_t L = 0;
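
A minimal host-side sketch of the fastdiv identity the comment above describes, n/d = (mulhi(n, mp) + n) >> L. The magic-constant computation follows the standard Granlund-Montgomery round-up method; the helper names are illustrative, not the in-tree ones, and the sum is widened to 64 bits here so the identity holds for every 32-bit n (the device code presumably relies on tensor indices being small enough for a 32-bit sum):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

struct fastdiv_vals { uint32_t mp, L; };

// Precompute mp and L for a fixed divisor d (illustrative counterpart of
// init_fastdiv_values; the exact in-tree constants may differ).
static fastdiv_vals init_fastdiv(uint32_t d) {
    assert(d != 0);
    uint32_t L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        L++; // L = ceil(log2(d))
    }
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return { mp, L };
}

static uint32_t fastdiv(uint32_t n, fastdiv_vals f) {
    const uint64_t hi = ((uint64_t) n * f.mp) >> 32; // mulhi(n, mp)
    return (uint32_t) ((hi + n) >> f.L);             // 64-bit sum avoids wraparound
}

int main() {
    for (uint32_t d : { 1u, 3u, 7u, 64u, 640u, 4096u, 1000003u }) {
        const fastdiv_vals f = init_fastdiv(d);
        for (uint32_t n : { 0u, 1u, d, d + 1, 123456789u, 4294967295u }) {
            assert(fastdiv(n, f) == n / d); // division replaced by mul + shift
        }
    }
    puts("fastdiv identity verified");
    return 0;
}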

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 24 additions & 2 deletions

@@ -50,6 +50,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml.h"

@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SET_ROWS:
             ggml_cuda_op_set_rows(ctx, dst);
             break;
+        case GGML_OP_SET:
+            ggml_cuda_op_set(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;

@@ -2974,7 +2978,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];

@@ -2993,7 +2997,7 @@
     }
 
     if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];

@@ -3114,9 +3118,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     // With the use of CUDA graphs, the execution will be performed by the graph launch.
     if (!use_cuda_graph || cuda_graph_update_required) {
 
+        [[maybe_unused]] int prev_i = 0;
+
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
 
+#ifdef GGML_CUDA_DEBUG
+            const int nodes_fused = i - prev_i - 1;
+            prev_i = i;
+            if (nodes_fused > 0) {
+                GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
+            }
+#endif
+
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                 continue;
             }
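
The GGML_CUDA_DEBUG block above reports how many graph nodes were folded into the previous kernel launch: when fusion advances i past several nodes in one iteration, i - prev_i - 1 counts the skipped ones. A standalone sketch of just that counting logic, with hypothetical node indices:

#include <cstdio>

int main() {
    const int executed[] = { 0, 1, 5, 6, 9 }; // indices of nodes actually launched (made up)
    int prev_i = 0;
    for (const int i : executed) {
        const int nodes_fused = i - prev_i - 1;
        prev_i = i;
        if (nodes_fused > 0) {
            printf("nodes_fused: %d\n", nodes_fused); // prints 3 (nodes 2-4), then 2 (nodes 7-8)
        }
    }
    return 0;
}
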
@@ -3842,6 +3857,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 op->src[0]->type == GGML_TYPE_F32 &&
                 (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
         } break;
+        case GGML_OP_SET:
+            {
+                const ggml_type t = op->type;
+                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
+                       t == op->src[0]->type &&
+                       t == op->src[1]->type;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 101 additions & 47 deletions

@@ -4,30 +4,53 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 // Generic quantized set_rows kernel template
-template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
-static __global__ void k_set_rows_quant(
-    const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
-    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-    const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t s10, const int64_t s11, const int64_t s12,
-    const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static __global__ void k_set_rows_quant(const float * __restrict__ src0,
+                                        const idx_t * __restrict__ src1,
+                                        block_type * __restrict__ dst,
+                                        const int64_t ne_total,
+                                        const int64_t ne10,
+                                        const int64_t ne11,
+                                        const int64_t ne12,
+                                        const int64_t ne13,
+                                        const int64_t s01,
+                                        const int64_t s02,
+                                        const int64_t s03,
+                                        const int64_t s10,
+                                        const int64_t s11,
+                                        const int64_t s12,
+                                        const int64_t s1,
+                                        const int64_t s2,
+                                        const int64_t s3,
+                                        const uint3 ne00,
+                                        const uint3 ne01,
+                                        const uint3 ne02,
+                                        const uint3 ne11_fd,
+                                        const uint3 ne12_fd) {
     const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
 
     if (i >= ne_total) {
         return;
     }
 
     const int64_t i_base = i * qk;
-    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t tmp = (uint32_t) i_base;
+    uint2 div_mod;
+
+    div_mod = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp = div_mod.x;
 
-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp = div_mod.x;
+
+    div_mod = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
     const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
     quantize_func(src_block, dst_block);
 
     GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
     GGML_UNUSED(ne13);
 }

@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
     const int64_t s2 = nb2;
     const int64_t s3 = nb3;
 
-    if (ne_total > 0) {
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
         k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
+            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
     }
 }
 
-template<typename src_t, typename idx_t, typename dst_t>
-static __global__ void k_set_rows(
-    const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
-    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-    const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t s10, const int64_t s11, const int64_t s12,
-    const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename src_t, typename idx_t, typename dst_t>
+static __global__ void k_set_rows(const src_t * __restrict__ src0,
+                                  const idx_t * __restrict__ src1,
+                                  dst_t * __restrict__ dst,
+                                  const int64_t ne_total,
+                                  const int64_t ne10,
+                                  const int64_t ne11,
+                                  const int64_t ne12,
+                                  const int64_t ne13,
+                                  const int64_t s01,
+                                  const int64_t s02,
+                                  const int64_t s03,
+                                  const int64_t s10,
+                                  const int64_t s11,
+                                  const int64_t s12,
+                                  const int64_t s1,
+                                  const int64_t s2,
+                                  const int64_t s3,
+                                  const uint3 ne00,
+                                  const uint3 ne01,
+                                  const uint3 ne02,
+                                  const uint3 ne11_fd,
+                                  const uint3 ne12_fd) {
     const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
 
     if (i >= ne_total) {
        return;
    }
 
-    const int64_t i03 = i / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t tmp = (uint32_t) i;
+    uint2 div_mod;
+
+    div_mod = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp = div_mod.x;
 
-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp = div_mod.x;
+
+    div_mod = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
     const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
     dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 
     GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
     GGML_UNUSED(ne13);
 }

@@ -144,14 +196,16 @@ static void set_rows_cuda(
     const int64_t s2 = nb2/sizeof(dst_t);
     const int64_t s3 = nb3/sizeof(dst_t);
 
-    if (ne_total > 0) {
-        k_set_rows<<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
+                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
+                                                         ne11_fd, ne12_fd);
     }
 }
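
The kernels above replace the chain of 64-bit divisions with precomputed fastdiv/fastmodulo steps that peel one dimension at a time, innermost first (fast_div_modulo returns the quotient in .x and the remainder in .y). A host-side sketch, using plain / and % as stand-ins, checking that the new decomposition matches the original index arithmetic:

#include <cassert>
#include <cstdint>

int main() {
    const int64_t ne00 = 64, ne01 = 5, ne02 = 4, ne03 = 3; // hypothetical dims

    for (int64_t i = 0; i < ne00 * ne01 * ne02 * ne03; i++) {
        // original arithmetic (the removed lines in the diff above)
        const int64_t i03 = i / (ne00 * ne01 * ne02);
        const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
        const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
        const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

        // new scheme: successive div/mod, one dimension per step
        int64_t tmp = i;
        const int64_t j00 = tmp % ne00; tmp /= ne00;
        const int64_t j01 = tmp % ne01; tmp /= ne01;
        const int64_t j02 = tmp % ne02;
        const int64_t j03 = tmp / ne02;

        assert(i00 == j00 && i01 == j01 && i02 == j02 && i03 == j03);
    }
    return 0;
}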

ggml/src/ggml-cuda/set.cu

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+#include "set.cuh"
+#include "cpy.cuh"
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
+    GGML_ASSERT(src1->type == src0->type);
+    GGML_ASSERT(dst ->type == src0->type);
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const size_t nb1     = ((int32_t *) dst->op_params)[0];
+    const size_t nb2     = ((int32_t *) dst->op_params)[1];
+    const size_t nb3     = ((int32_t *) dst->op_params)[2];
+    const size_t offset  = ((int32_t *) dst->op_params)[3];
+    const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        ggml_cuda_cpy(ctx, src0, dst);
+    }
+
+    ggml_tensor dst_view = *dst;
+    dst_view.data  = (void *) ((char *) dst->data + offset);
+    dst_view.ne[0] = src1->ne[0];
+    dst_view.ne[1] = src1->ne[1];
+    dst_view.ne[2] = src1->ne[2];
+    dst_view.ne[3] = src1->ne[3];
+
+    dst_view.nb[0] = ggml_element_size(dst);
+    dst_view.nb[1] = nb1;
+    dst_view.nb[2] = nb2;
+    dst_view.nb[3] = nb3;
+
+    ggml_cuda_cpy(ctx, src1, &dst_view);
+}
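
A minimal sketch of how a graph node reaches this op, and of the op_params layout {nb1, nb2, nb3, offset, inplace} read back above. It assumes the standard ggml public API (ggml_init, ggml_new_tensor_2d, ggml_set); the shapes and write position are illustrative:

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8); // destination
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4); // patch written into a

    // place b at row 2, column 2 of a; the view keeps a's strides
    const size_t offset = 2*a->nb[1] + 2*a->nb[0];
    ggml_tensor * c = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset);

    // the five int32 op_params consumed by ggml_cuda_op_set
    const int32_t * p = (const int32_t *) c->op_params;
    printf("nb1=%d nb2=%d nb3=%d offset=%d inplace=%d\n", p[0], p[1], p[2], p[3], p[4]);

    ggml_free(ctx);
    return 0;
}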

ggml/src/ggml-cuda/set.cuh

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
