
Commit 1e5148f

tamarPal authored and committed
Clean SSM_CONV code - remove all comments for production
Removed all inline comments and documentation from the implementation. Clean, minimal code ready for production merge.
1 parent f3c0ac9 commit 1e5148f

File tree

1 file changed: +30 −54 lines changed


ggml/src/ggml-sycl/ssm_conv.cpp

Lines changed: 30 additions & 54 deletions
@@ -3,24 +3,21 @@
 
 using namespace sycl;
 
-// SSM_CONV kernel: State Space Model Convolution 1D
-// This implements a sliding window convolution with history context
 static void kernel_ssm_conv(
     queue &q,
-    const float *src_data,           // input sequence [d_conv-1+n_t, d_inner, n_s]
-    const float *weights,            // convolution weights [d_conv, d_inner]
-    float *dst_data,                 // output [d_inner, n_t, n_s]
-    int d_conv,                      // convolution window size
-    int d_inner,                     // number of inner channels
-    int n_t,                         // number of tokens to process
-    int n_s,                         // batch size (number of sequences)
-    int ncs __attribute__((unused)), // input sequence length (d_conv-1+n_t)
-    int src_stride_inner,            // stride between channels in src
-    int src_stride_seq,              // stride between sequences in src
-    int dst_stride_token,            // stride between tokens in dst
-    int dst_stride_seq               // stride between sequences in dst
+    const float *src_data,
+    const float *weights,
+    float *dst_data,
+    int d_conv,
+    int d_inner,
+    int n_t,
+    int n_s,
+    int ncs __attribute__((unused)),
+    int src_stride_inner,
+    int src_stride_seq,
+    int dst_stride_token,
+    int dst_stride_seq
 ) {
-    // Each work item handles one (channel, token, sequence) combination
     const size_t total_work = d_inner * n_t * n_s;
     const size_t work_group_size = 256;
     const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
@@ -34,31 +31,18 @@ static void kernel_ssm_conv(
 
     if (idx >= total_work) return;
 
-    // Decode indices: idx = seq * (d_inner * n_t) + token * d_inner + channel
     const int channel = idx % d_inner;
     const int token = (idx / d_inner) % n_t;
     const int seq = idx / (d_inner * n_t);
 
-    // Calculate input starting position for this token and channel
-    // Input layout: [d_conv-1+n_t, d_inner, n_s]
-    // Following CPU implementation: s[i0 + i1*ncs] where i0 is conv position, i1 is channel
-    // Note: s pointer is offset by token position for sliding window
     const float *s = src_data + seq * src_stride_seq + channel * src_stride_inner + token;
-
-    // Get weights for this channel
-    // Weights layout: [d_conv, d_inner]
-    // Following CPU implementation: c[i0 + i1*nc] where i0 is conv position, i1 is channel
     const float *c = weights + channel * d_conv;
 
-    // Perform dot product: sum(input_window * weights)
-    // Following CPU implementation exactly
     float sumf = 0.0f;
     for (int i0 = 0; i0 < d_conv; ++i0) {
-        sumf += s[i0] * c[i0]; // s[i0 + i1*ncs] * c[i0 + i1*nc]
+        sumf += s[i0] * c[i0];
     }
 
-    // Write result to output
-    // Output layout: [d_inner, n_t, n_s]
     const size_t dst_idx = seq * dst_stride_seq +
                            token * dst_stride_token +
                            channel;
@@ -68,41 +52,34 @@ static void kernel_ssm_conv(
 }
 
 void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * src0 = dst->src[0]; // conv_x: input sequence
-    ggml_tensor * src1 = dst->src[1]; // conv1d.weight: convolution weights
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    // Extract dimensions following CPU implementation
-    const int d_conv = src1->ne[0];  // convolution window size
-    const int ncs = src0->ne[0];     // d_conv - 1 + n_t (input sequence length)
-    const int d_inner = src0->ne[1]; // number of inner channels
-    const int n_t = dst->ne[1];      // number of tokens to process
-    const int n_s = dst->ne[2];      // batch size (number of sequences)
+    const int d_conv = src1->ne[0];
+    const int ncs = src0->ne[0];
+    const int d_inner = src0->ne[1];
+    const int n_t = dst->ne[1];
+    const int n_s = dst->ne[2];
 
-    // Verify dimensions match CPU implementation exactly
-    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t); // input length
-    GGML_ASSERT(src0->ne[1] == d_inner);          // channels match
-    GGML_ASSERT(src1->ne[1] == d_inner);          // weight channels match
-    GGML_ASSERT(dst->ne[0] == d_inner);           // output channels
-    GGML_ASSERT(dst->ne[1] == n_t);               // output tokens
-    GGML_ASSERT(dst->ne[2] == n_s);               // output sequences
+    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(src0->ne[1] == d_inner);
+    GGML_ASSERT(src1->ne[1] == d_inner);
+    GGML_ASSERT(dst->ne[0] == d_inner);
+    GGML_ASSERT(dst->ne[1] == n_t);
+    GGML_ASSERT(dst->ne[2] == n_s);
 
-    // Verify stride assumptions (from CPU implementation)
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
-    // Calculate strides based on tensor layout (in elements, not bytes)
-    // src0: [d_conv-1+n_t, d_inner, n_s] - input sequence
-    const int src_stride_inner = ncs;         // stride between channels in elements
-    const int src_stride_seq = ncs * d_inner; // stride between sequences in elements
-
-    // dst: [d_inner, n_t, n_s] - output
-    const int dst_stride_token = d_inner;       // stride between tokens in elements
-    const int dst_stride_seq = d_inner * n_t;   // stride between sequences in elements
+    const int src_stride_inner = ncs;
+    const int src_stride_seq = ncs * d_inner;
+    const int dst_stride_token = d_inner;
+    const int dst_stride_seq = d_inner * n_t;
 
     try {
         queue *q = ctx.stream();
@@ -113,7 +90,6 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
         GGML_ASSERT(src_data && weights && dst_data);
 
-        // Launch kernel
         kernel_ssm_conv(
             *q, src_data, weights, dst_data,
             d_conv, d_inner, n_t, n_s, ncs,
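
For reference, the indexing scheme that the removed comments documented can be summarized in a small standalone CPU sketch. This is illustrative only and not part of the commit: the helper name ssm_conv_reference and the toy sizes in main are hypothetical, while the layouts (input [d_conv-1+n_t, d_inner, n_s], weights [d_conv, d_inner], output [d_inner, n_t, n_s]) and the flat-index decomposition mirror the kernel shown in the diff.

    // Illustrative CPU reference for the SSM_CONV indexing (assumed helper, not part of the commit).
    #include <cstdio>
    #include <vector>

    static void ssm_conv_reference(const float *src, const float *w, float *dst,
                                   int d_conv, int d_inner, int n_t, int n_s) {
        const int ncs              = d_conv - 1 + n_t; // input sequence length
        const int src_stride_inner = ncs;              // stride between channels (elements)
        const int src_stride_seq   = ncs * d_inner;    // stride between sequences (elements)
        const int dst_stride_token = d_inner;
        const int dst_stride_seq   = d_inner * n_t;

        const int total_work = d_inner * n_t * n_s;
        for (int idx = 0; idx < total_work; ++idx) {
            // Same decomposition as the kernel: idx = seq*(d_inner*n_t) + token*d_inner + channel
            const int channel = idx % d_inner;
            const int token   = (idx / d_inner) % n_t;
            const int seq     = idx / (d_inner * n_t);

            // Sliding window starts at the token offset within this channel's row
            const float *s = src + seq * src_stride_seq + channel * src_stride_inner + token;
            const float *c = w + channel * d_conv;

            float sumf = 0.0f;
            for (int i0 = 0; i0 < d_conv; ++i0) {
                sumf += s[i0] * c[i0]; // window dot product
            }
            dst[seq * dst_stride_seq + token * dst_stride_token + channel] = sumf;
        }
    }

    int main() {
        const int d_conv = 4, d_inner = 2, n_t = 3, n_s = 1;
        std::vector<float> src((d_conv - 1 + n_t) * d_inner * n_s, 1.0f);
        std::vector<float> w(d_conv * d_inner, 0.5f);
        std::vector<float> dst(d_inner * n_t * n_s, 0.0f);
        ssm_conv_reference(src.data(), w.data(), dst.data(), d_conv, d_inner, n_t, n_s);
        printf("dst[0] = %f\n", dst[0]); // expect 4 * 1.0 * 0.5 = 2.0
        return 0;
    }

A sketch like this can serve as a quick cross-check of the SYCL kernel's output on small shapes.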
