using namespace sycl;

-// SSM_CONV kernel: State Space Model Convolution 1D
-// This implements a sliding window convolution with history context
static void kernel_ssm_conv(
    queue &q,
-    const float *src_data,           // input sequence [d_conv-1+n_t, d_inner, n_s]
-    const float *weights,            // convolution weights [d_conv, d_inner]
-    float *dst_data,                 // output [d_inner, n_t, n_s]
-    int d_conv,                      // convolution window size
-    int d_inner,                     // number of inner channels
-    int n_t,                         // number of tokens to process
-    int n_s,                         // batch size (number of sequences)
-    int ncs __attribute__((unused)), // input sequence length (d_conv-1+n_t)
-    int src_stride_inner,            // stride between channels in src
-    int src_stride_seq,              // stride between sequences in src
-    int dst_stride_token,            // stride between tokens in dst
-    int dst_stride_seq               // stride between sequences in dst
+    const float *src_data,
+    const float *weights,
+    float *dst_data,
+    int d_conv,
+    int d_inner,
+    int n_t,
+    int n_s,
+    int ncs __attribute__((unused)),
+    int src_stride_inner,
+    int src_stride_seq,
+    int dst_stride_token,
+    int dst_stride_seq
) {
-    // Each work item handles one (channel, token, sequence) combination
    const size_t total_work = d_inner * n_t * n_s;
    const size_t work_group_size = 256;
    const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
@@ -34,31 +31,18 @@ static void kernel_ssm_conv(

    if (idx >= total_work) return;

-    // Decode indices: idx = seq * (d_inner * n_t) + token * d_inner + channel
    const int channel = idx % d_inner;
    const int token = (idx / d_inner) % n_t;
    const int seq = idx / (d_inner * n_t);

-    // Calculate input starting position for this token and channel
-    // Input layout: [d_conv-1+n_t, d_inner, n_s]
-    // Following CPU implementation: s[i0 + i1*ncs] where i0 is conv position, i1 is channel
-    // Note: s pointer is offset by token position for sliding window
    const float *s = src_data + seq * src_stride_seq + channel * src_stride_inner + token;
-
-    // Get weights for this channel
-    // Weights layout: [d_conv, d_inner]
-    // Following CPU implementation: c[i0 + i1*nc] where i0 is conv position, i1 is channel
    const float *c = weights + channel * d_conv;

-    // Perform dot product: sum(input_window * weights)
-    // Following CPU implementation exactly
    float sumf = 0.0f;
    for (int i0 = 0; i0 < d_conv; ++i0) {
-        sumf += s[i0] * c[i0]; // s[i0 + i1*ncs] * c[i0 + i1*nc]
+        sumf += s[i0] * c[i0];
    }

-    // Write result to output
-    // Output layout: [d_inner, n_t, n_s]
    const size_t dst_idx = seq * dst_stride_seq +
                           token * dst_stride_token +
                           channel;
@@ -68,41 +52,34 @@ static void kernel_ssm_conv(
}

void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * src0 = dst->src[0]; // conv_x: input sequence
-    ggml_tensor * src1 = dst->src[1]; // conv1d.weight: convolution weights
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    // Extract dimensions following CPU implementation
-    const int d_conv = src1->ne[0];   // convolution window size
-    const int ncs = src0->ne[0];      // d_conv - 1 + n_t (input sequence length)
-    const int d_inner = src0->ne[1];  // number of inner channels
-    const int n_t = dst->ne[1];       // number of tokens to process
-    const int n_s = dst->ne[2];       // batch size (number of sequences)
+    const int d_conv = src1->ne[0];
+    const int ncs = src0->ne[0];
+    const int d_inner = src0->ne[1];
+    const int n_t = dst->ne[1];
+    const int n_s = dst->ne[2];

-    // Verify dimensions match CPU implementation exactly
-    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t); // input length
-    GGML_ASSERT(src0->ne[1] == d_inner);          // channels match
-    GGML_ASSERT(src1->ne[1] == d_inner);          // weight channels match
-    GGML_ASSERT(dst->ne[0] == d_inner);           // output channels
-    GGML_ASSERT(dst->ne[1] == n_t);               // output tokens
-    GGML_ASSERT(dst->ne[2] == n_s);               // output sequences
+    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(src0->ne[1] == d_inner);
+    GGML_ASSERT(src1->ne[1] == d_inner);
+    GGML_ASSERT(dst->ne[0] == d_inner);
+    GGML_ASSERT(dst->ne[1] == n_t);
+    GGML_ASSERT(dst->ne[2] == n_s);

-    // Verify stride assumptions (from CPU implementation)
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

-    // Calculate strides based on tensor layout (in elements, not bytes)
-    // src0: [d_conv-1+n_t, d_inner, n_s] - input sequence
-    const int src_stride_inner = ncs;          // stride between channels in elements
-    const int src_stride_seq = ncs * d_inner;  // stride between sequences in elements
-
-    // dst: [d_inner, n_t, n_s] - output
-    const int dst_stride_token = d_inner;      // stride between tokens in elements
-    const int dst_stride_seq = d_inner * n_t;  // stride between sequences in elements
+    const int src_stride_inner = ncs;
+    const int src_stride_seq = ncs * d_inner;
+    const int dst_stride_token = d_inner;
+    const int dst_stride_seq = d_inner * n_t;

    try {
        queue *q = ctx.stream();
@@ -113,7 +90,6 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

        GGML_ASSERT(src_data && weights && dst_data);

-        // Launch kernel
        kernel_ssm_conv(
            *q, src_data, weights, dst_data,
            d_conv, d_inner, n_t, n_s, ncs,
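Note: the comments stripped by this change describe the computation the kernel performs. For reference, a minimal host-side C++ sketch of the same per-element convolution, written against the strides set up in ggml_sycl_ssm_conv (the name ssm_conv_reference and the plain nested loops are illustrative assumptions, not code from this patch):

// Illustrative sketch only: sliding-window conv over a padded input of
// length ncs = d_conv - 1 + n_t, one dot product per (seq, token, channel).
static void ssm_conv_reference(const float *src, const float *w, float *dst,
                               int d_conv, int d_inner, int n_t, int n_s) {
    const int ncs = d_conv - 1 + n_t;  // padded input length per channel
    for (int seq = 0; seq < n_s; ++seq) {
        for (int token = 0; token < n_t; ++token) {
            for (int ch = 0; ch < d_inner; ++ch) {
                // src layout [ncs, d_inner, n_s]: window starts at `token`
                const float *s = src + seq * ncs * d_inner + ch * ncs + token;
                // weights layout [d_conv, d_inner]
                const float *c = w + ch * d_conv;
                float sumf = 0.0f;
                for (int i0 = 0; i0 < d_conv; ++i0) {
                    sumf += s[i0] * c[i0];
                }
                // dst layout [d_inner, n_t, n_s]
                dst[seq * d_inner * n_t + token * d_inner + ch] = sumf;
            }
        }
    }
}

Each SYCL work item in kernel_ssm_conv computes one (channel, token, sequence) output, i.e. one iteration of the three outer loops above.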