Commit 1e64567

clang format
1 parent 6a6c954 commit 1e64567

File tree

4 files changed: 198 additions & 231 deletions


ggml/src/ggml-cuda/ssm_conv.cu

Lines changed: 61 additions & 73 deletions
@@ -1,94 +1,82 @@
 #include "ssm_conv.cuh"
 
 template <int block_size>
-static __global__ void ssm_conv_f32(const float *__restrict__ src0,
-                                    const float *__restrict__ src1,
-                                    const int src0_nb0, const int src0_nb1,
-                                    const int src0_nb2, const int src1_nb1,
-                                    float *__restrict__ dst, const int dst_nb0,
-                                    const int dst_nb1, const int dst_nb2,
-                                    const int nc, const int ncs, const int nr,
-                                    const int n_t, const int n_s) {
-  const int tid = blockIdx.y;
-  const int i3 = blockIdx.x;
-  const int i2 = threadIdx.x;
+static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
+                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
+                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+    const int tid = blockIdx.y;
+    const int i3  = blockIdx.x;
+    const int i2  = threadIdx.x;
 
-  const int ith = tid;
-  const int nth = WARP_SIZE;
+    const int ith = tid;
+    const int nth = WARP_SIZE;
 
-  // rows per thread
-  const int dr = (nr + nth - 1) / nth;
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
 
-  // row range for this thread
-  const int ir0 = dr * ith;
-  const int ir1 = min(ir0 + dr, nr);
-  const int ir = ir1 - ir0;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = min(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
 
-  // {d_conv - 1 + n_t, d_inner, n_seqs}
-  // sliding window
-  const float *s =
-      (const float *)((const char *)src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
-                      i3 * src0_nb2); // {d_conv, d_inner, n_s}
-  const float *c = (const float *)((const char *)src1 +
-                                   ir0 * src1_nb1); // {d_conv, d_inner}
-  float *x = (float *)((char *)dst + ir0 * dst_nb0 + i2 * dst_nb1 +
-                       i3 * dst_nb2); // {d_inner, n_t, n_s}
+    // {d_conv - 1 + n_t, d_inner, n_seqs}
+    // sliding window
+    const float * s = (const float *) ((const char *) src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
+                                       i3 * src0_nb2);  // {d_conv, d_inner, n_s}
+    const float * c = (const float *) ((const char *) src1 + ir0 * src1_nb1);  // {d_conv, d_inner}
+    float * x = (float *) ((char *) dst + ir0 * dst_nb0 + i2 * dst_nb1 + i3 * dst_nb2);  // {d_inner, n_t, n_s}
 
-  // TODO: transpose the output for smaller strides for big batches?
-  // d_inner
-  for (int i1 = 0; i1 < ir; ++i1) {
-    // rowwise dot product
-    // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
-    float sumf = 0.0f;
+    // TODO: transpose the output for smaller strides for big batches?
+    // d_inner
+    for (int i1 = 0; i1 < ir; ++i1) {
+        // rowwise dot product
+        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+        float sumf = 0.0f;
 
 // d_conv
 #pragma unroll
-    for (int i0 = 0; i0 < nc; ++i0) {
-      sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
+        for (int i0 = 0; i0 < nc; ++i0) {
+            sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
+        }
+        x[i1] = sumf;
     }
-    x[i1] = sumf;
-  }
 }
 
-static void ssm_conv_f32_cuda(const float *src0, const float *src1,
-                              const int src0_nb0, const int src0_nb1,
-                              const int src0_nb2, const int src1_nb1,
-                              float *dst, const int dst_nb0, const int dst_nb1,
-                              const int dst_nb2, const int nc, const int ncs,
-                              const int nr, const int n_t, const int n_s,
-                              cudaStream_t stream) {
-  const dim3 block_dims(n_t, 1, 1);
-  // const int nblocks = n_s; // TODO
-  const dim3 grid_dims(n_s, WARP_SIZE, 1);
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
+                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
+                              const int n_s, cudaStream_t stream) {
+    const dim3 block_dims(n_t, 1, 1);
+    // const int nblocks = n_s; // TODO
+    const dim3 grid_dims(n_s, WARP_SIZE, 1);
 
-  ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
-      src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1,
-      dst_nb2, nc, ncs, nr, n_t, n_s);
+    ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
+        src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context &ctx, ggml_tensor *dst) {
-  const struct ggml_tensor *src0 = dst->src[0]; // conv_x
-  const struct ggml_tensor *src1 = dst->src[1]; // conv1d.weight
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];  // conv_x
+    const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
 
-  const int nc = src1->ne[0]; // d_conv
-  const int ncs = src0->ne[0]; // d_conv - 1 + n_t
-  const int nr = src0->ne[1]; // d_inner
-  const int n_t = dst->ne[1]; // tokens per sequence
-  const int n_s = dst->ne[2]; // number of sequences in the batch
+    const int nc  = src1->ne[0];  // d_conv
+    const int ncs = src0->ne[0];  // d_conv - 1 + n_t
+    const int nr  = src0->ne[1];  // d_inner
+    const int n_t = dst->ne[1];   // tokens per sequence
+    const int n_s = dst->ne[2];   // number of sequences in the batch
 
-  GGML_ASSERT(dst->ne[0] == nr);
-  GGML_ASSERT(src0->nb[0] == sizeof(float));
-  GGML_ASSERT(src1->nb[0] == sizeof(float));
-  GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
-  const float *src0_d = (const float *)src0->data;
-  const float *src1_d = (const float *)src1->data;
-  float *dst_d = (float *)dst->data;
-  cudaStream_t stream = ctx.stream();
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
 
-  GGML_ASSERT(src0->type == GGML_TYPE_F32);
-  GGML_ASSERT(dst->type == GGML_TYPE_F32);
-  ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2],
-                    src1->nb[1], dst_d, dst->nb[0], dst->nb[1], dst->nb[2], nc,
-                    ncs, nr, n_t, n_s, stream);
-}
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
+                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+}
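
Illustrative sketch (not part of the commit): the kernel above computes, for every (sequence, token) pair, a rowwise dot product between a d_conv-wide sliding window of conv_x and the per-row conv1d weights, with blockIdx.x selecting the sequence, blockIdx.y selecting a chunk of rows, and threadIdx.x selecting the token. The standalone CUDA program below reproduces that mapping with the same launch shape (grid = (n_s, WARP_SIZE), block = (n_t)), but assumes fully contiguous tensors instead of the byte strides passed to the real kernel; the _sketch names and the test sizes are made up for illustration.

// ssm_conv_sketch.cu -- hypothetical standalone demo, not from the repository
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Contiguous-layout variant of the rowwise sliding-window convolution:
// src: {ncs, nr, n_s} with ncs = nc - 1 + n_t, w: {nc, nr}, dst: {nr, n_t, n_s}
static __global__ void ssm_conv_f32_sketch(const float * src, const float * w, float * dst,
                                           const int nc, const int ncs, const int nr, const int n_t) {
    const int ith = blockIdx.y;   // which chunk of rows this block handles
    const int i3  = blockIdx.x;   // sequence index
    const int i2  = threadIdx.x;  // token index within the sequence

    // same rows-per-thread split as ssm_conv_f32 above
    const int dr  = (nr + WARP_SIZE - 1) / WARP_SIZE;
    const int ir0 = dr * ith;
    const int ir1 = min(ir0 + dr, nr);

    for (int i1 = ir0; i1 < ir1; ++i1) {
        float sumf = 0.0f;
        for (int i0 = 0; i0 < nc; ++i0) {
            // window of nc samples starting at token i2 of row i1, sequence i3
            sumf += src[(i3 * nr + i1) * ncs + i2 + i0] * w[i1 * nc + i0];
        }
        dst[(i3 * n_t + i2) * nr + i1] = sumf;
    }
}

int main() {
    const int nc = 4, n_t = 8, nr = 64, n_s = 2;  // made-up test sizes
    const int ncs = nc - 1 + n_t;

    const int n_src = n_s * nr * ncs, n_w = nr * nc, n_dst = n_s * n_t * nr;
    float * src = (float *) malloc(n_src * sizeof(float));
    float * w   = (float *) malloc(n_w   * sizeof(float));
    float * out = (float *) malloc(n_dst * sizeof(float));
    for (int i = 0; i < n_src; ++i) { src[i] = (float) (i % 7) - 3.0f; }
    for (int i = 0; i < n_w;   ++i) { w[i]   = (float) (i % 5) * 0.25f; }

    float * d_src; float * d_w; float * d_dst;
    cudaMalloc(&d_src, n_src * sizeof(float));
    cudaMalloc(&d_w,   n_w   * sizeof(float));
    cudaMalloc(&d_dst, n_dst * sizeof(float));
    cudaMemcpy(d_src, src, n_src * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_w,   w,   n_w   * sizeof(float), cudaMemcpyHostToDevice);

    // same launch shape as ssm_conv_f32_cuda: grid (n_s, WARP_SIZE), block (n_t)
    ssm_conv_f32_sketch<<<dim3(n_s, WARP_SIZE, 1), dim3(n_t, 1, 1)>>>(d_src, d_w, d_dst, nc, ncs, nr, n_t);
    cudaMemcpy(out, d_dst, n_dst * sizeof(float), cudaMemcpyDeviceToHost);

    // CPU reference: identical loop nest and summation order
    int mismatches = 0;
    for (int i3 = 0; i3 < n_s; ++i3) {
        for (int i2 = 0; i2 < n_t; ++i2) {
            for (int i1 = 0; i1 < nr; ++i1) {
                float ref = 0.0f;
                for (int i0 = 0; i0 < nc; ++i0) {
                    ref += src[(i3 * nr + i1) * ncs + i2 + i0] * w[i1 * nc + i0];
                }
                mismatches += fabsf(ref - out[(i3 * n_t + i2) * nr + i1]) > 1e-5f;
            }
        }
    }
    printf("mismatches: %d\n", mismatches);

    cudaFree(d_src); cudaFree(d_w); cudaFree(d_dst);
    free(src); free(w); free(out);
    return 0;
}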

ggml/src/ggml-cuda/ssm_conv.cuh

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context& ctx, ggml_tensor* dst);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments