Commit e38e857

Jeemzz committed

draft: cuda set op
1 parent 9a53f40 commit e38e857

File tree

2 files changed: +58 −101 lines


ggml/src/ggml-cuda/set.cu

Lines changed: 58 additions & 49 deletions
@@ -1,75 +1,78 @@
 #include "ggml-cuda/common.cuh"
 #include "set.cuh"
 
-static __global__ void set_f32_cuda_copy(const float * __restrict__ src1,
+static __global__ void set_f32_cuda_copy(const float * __restrict__ src0,
                                          float * __restrict__ dst,
                                          const size_t ne0,
                                          const size_t ne1,
                                          const size_t ne2,
                                          const size_t ne3,
-                                         const int offset, // element-offset
-                                         const int nb1,    // stride in elements along dim1
-                                         const int nb2,    // stride in elements along dim2
-                                         const int nb3     // stride in elements along dim3
-) {
+                                         const size_t nb0,
+                                         const size_t nb1,
+                                         const size_t nb2,
+                                         const size_t nb3) {
     const size_t total = ne0 * ne1 * ne2 * ne3;
     const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
     if (gid >= total) {
         return;
     }
 
-    // unravel into 4D indices (i0 fastest, then i1, i2, i3):
-    size_t tmp = gid;
-    const size_t i0 = tmp % ne0;
+    size_t tmp = gid;
+
+    const size_t i0 = tmp % ne0;
     tmp /= ne0;
     const size_t i1 = tmp % ne1;
     tmp /= ne1;
     const size_t i2 = tmp % ne2;
     tmp /= ne2;
-    const size_t i3 = tmp; // < ne3
+    const size_t i3 = tmp;
 
-    // compute flat positions with strides + offset
-    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+    const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
 
-    dst[pos] = src1[pos];
+    *((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos));
 }
 
-static __global__ void set_f32_cuda(const float * __restrict__ src0,
+static __global__ void set_f32_cuda(const float * __restrict__ src1,
                                     float * __restrict__ dst,
-                                    const size_t ne0,
-                                    const size_t ne1,
-                                    const size_t ne2,
-                                    const size_t ne3,
-                                    const int offset, // element-offset into dst
-                                    const int nb1,    // stride in elements along dim1
-                                    const int nb2,    // stride in elements along dim2
-                                    const int nb3     // stride in elements along dim3
+                                    const size_t ne10,
+                                    const size_t ne11,
+                                    const size_t ne12,
+                                    const size_t ne13,
+                                    const size_t nb10,
+                                    const size_t nb11,
+                                    const size_t nb12,
+                                    const size_t nb13,
+                                    const size_t nb0,
+                                    const size_t nb1,
+                                    const size_t nb2,
+                                    const size_t nb3,
+                                    const size_t offset
+
 ) {
-    // src0 is contiguous over ne0*ne1*ne2*ne3 elements
-    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const size_t total = ne10 * ne11 * ne12 * ne13;
     const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
     if (gid >= total) {
         return;
     }
 
-    // unravel gid to 4D (same as copy)
-    size_t tmp = gid;
-    const size_t i0 = tmp % ne0;
-    tmp /= ne0;
-    const size_t i1 = tmp % ne1;
-    tmp /= ne1;
-    const size_t i2 = tmp % ne2;
-    tmp /= ne2;
+    size_t tmp = gid;
+
+    const size_t i0 = tmp % ne10;
+    tmp /= ne10;
+    const size_t i1 = tmp % ne11;
+    tmp /= ne11;
+    const size_t i2 = tmp % ne12;
+    tmp /= ne12;
     const size_t i3 = tmp;
 
-    // dst position has the same formula:
-    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+    size_t dst_offset = offset + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3;
+    size_t src1_offset = i0 * nb10 + i1 * nb11 + i2 * nb12 + i3 * nb13;
 
-    // src0 is contiguous: flat index = gid
-    dst[pos] = src0[gid];
+    *((float *) ((char *) dst + dst_offset)) = *((const float *) ((const char *) src1 + src1_offset));
 }
 
 void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     const int32_t nb1 = dst->op_params[0];
     const int32_t nb2 = dst->op_params[1];
     const int32_t nb3 = dst->op_params[2];
@@ -80,31 +83,37 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src1 = dst->src[1];
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // TODO: support more dtypes.
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    // dims
-    const size_t ne0 = dst->ne[0];
-    const size_t ne1 = dst->ne[1];
-    const size_t ne2 = dst->ne[2];
-    const size_t ne3 = dst->ne[3];
+    GGML_TENSOR_BINARY_OP_LOCALS01;
+    const int nb0 = ggml_element_size(dst);
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
     float * dst_d = (float *) dst->data;
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t total = ne0 * ne1 * ne2 * ne3;
-    const int threads = 256;
-    const int blocks = (total + threads - 1) / threads;
-
     if (!inplace) {
-        // copy whole src1 -> dst
-        set_f32_cuda_copy<<<blocks, threads, 0, stream>>>(src1_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+        // copy whole src0 -> dst.
+        const size_t total = ne00 * ne01 * ne02 * ne03;
+
+        const int num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+        set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+            src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03);
     }
 
-    // then overwrite from src0 -> dst at same offsets/strides
-    set_f32_cuda<<<blocks, threads, 0, stream>>>(src0_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+    // set: src1 -> dst
+    // set_f32_cuda
+
+    const size_t total = ne10 * ne11 * ne12 * ne13;
+    const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+    set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+        src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset);
 }
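
Reviewer note: both kernels share one indexing scheme. Each thread takes its global id over the ne0*ne1*ne2*ne3 iteration space, peels off (i0, i1, i2, i3) with mod/div, and converts the indices to a byte offset using ggml's nb* byte strides. A minimal host-side sketch of the same arithmetic (hypothetical helper, not part of this commit):

#include <stddef.h>

// Hypothetical helper mirroring the unravel-and-stride arithmetic in
// set_f32_cuda_copy/set_f32_cuda. ne[4] holds element counts, nb[4]
// byte strides (ggml convention: nb[0] == sizeof(float) for contiguous
// f32 tensors; higher dims carry the row/plane strides).
static size_t byte_offset_of(size_t gid, const size_t ne[4], const size_t nb[4]) {
    size_t tmp = gid;
    const size_t i0 = tmp % ne[0]; tmp /= ne[0];
    const size_t i1 = tmp % ne[1]; tmp /= ne[1];
    const size_t i2 = tmp % ne[2]; tmp /= ne[2];
    const size_t i3 = tmp; // < ne[3] for in-range gid
    return i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
}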

ggml/src/ggml-cuda/set1.cu

Lines changed: 0 additions & 52 deletions
This file was deleted.
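
For context on where offset/nb1/nb2/nb3 originate: assuming the upstream ggml_set() API from ggml.h, a GGML_OP_SET node packs the destination byte strides, the byte offset, and the inplace flag into dst->op_params, which ggml_cuda_op_set unpacks above. A minimal usage sketch; the helper name and shapes are illustrative only:

#include "ggml.h"

// Illustrative only: paste a 3x2 block b into a 4x5 tensor a, starting
// at element (i0=1, i1=2). Strides and the offset are in bytes, which
// is exactly what the CUDA kernels above consume.
static struct ggml_tensor * set_block_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 5);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);

    const size_t offset = 1 * a->nb[0] + 2 * a->nb[1]; // byte offset of the corner

    // ggml_set stores nb1, nb2, nb3, offset and the inplace flag in
    // op_params; the CUDA op reads them back in the same order.
    return ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset);
}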
