@@ -9,7 +9,7 @@ using namespace cub;
 
 #include "ssm-scan.cuh"
 
-template <size_t splitD, size_t N>
+template <size_t splitD, size_t N, size_t L_template>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
                  const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
@@ -18,7 +18,7 @@ __global__ void __launch_bounds__(splitD, 2)
                  const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
                  float *__restrict__ dst, const int64_t L_param)
 {
-    const size_t L = L_param;
+    const size_t L = L_template == 0 ? L_param : L_template;
     const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
     const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
     const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
@@ -62,6 +62,7 @@ __global__ void __launch_bounds__(splitD, 2)
     }
 #endif
 
+#pragma unroll
     for (size_t i = 0; i < L; i++)
     {
         if (threadIdx.x < N)
@@ -101,89 +102,6 @@ __global__ void __launch_bounds__(splitD, 2)
 #endif
 }
 
-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_single_step_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
-                             const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
-                             const int src0_nb1, const int src0_nb2, const int src1_nb2,
-                             const int src1_nb3, const int src2_nb2, const int src3_nb1,
-                             const int src4_nb2, const int src5_nb2,
-                             float *__restrict__ dst)
-{
-    const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
-    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
-    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
-    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
-    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
-    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
-    float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
-    float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
-
-    float regA[N];
-    float regs0[N];
-
-    __shared__ float smemB[N];
-    __shared__ float smemC[N];
-
-#ifdef USE_CUB
-    using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
-    using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
-    using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
-
-    __shared__ typename BlockLoadA::TempStorage block_load_tempA;
-    __shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
-    __shared__ typename BlockStoreS::TempStorage block_store_tempS;
-
-    BlockLoadA(block_load_tempA).Load(A_block, regA);
-    BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
-#else
-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_A = src3_nb1 / sizeof(float);
-#pragma unroll
-    for (size_t n = 0; n < N; ++n)
-    {
-        regA[n] = A_block[threadIdx.x * stride_A + n];
-        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
-    }
-#endif
-
-    if (threadIdx.x < N)
-    {
-        smemB[threadIdx.x] = B_block[threadIdx.x];
-        smemC[threadIdx.x] = C_block[threadIdx.x];
-    }
-    __syncthreads();
-
-    {
-        float dt_soft_plus = dt_block[threadIdx.x];
-        if (dt_soft_plus <= 20.0f)
-        {
-            dt_soft_plus = log1pf(expf(dt_soft_plus));
-        }
-        float x_dt = x_block[threadIdx.x] * dt_soft_plus;
-        float sumf = 0.0f;
-#pragma unroll
-        for (size_t n = 0; n < N; n++)
-        {
-            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
-            sumf += state * smemC[n];
-            regs0[n] = state;
-        }
-        y_block[threadIdx.x] = sumf;
-    }
-
-#ifdef USE_CUB
-    BlockStoreS(block_store_tempS).Store(s_block, regs0);
-#else
-    const int stride_s = stride_s0;
-#pragma unroll
-    for (size_t n = 0; n < N; ++n)
-    {
-        s_block[threadIdx.x * stride_s + n] = regs0[n];
-    }
-#endif
-}
-
 static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float *src2, const float *src3,
                               const float *src4, const float *src5, const int src0_nb1, const int src0_nb2,
                               const int src1_nb1, const int src1_nb2, const int src1_nb3,
@@ -198,19 +116,46 @@ static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float
     const dim3 blocks(B, (D + threads - 1) / threads, 1);
     if (N == 16)
     {
-        if (L > 1)
+        switch (L)
         {
-            ssm_scan_f32<threads, 16><<<blocks, threads, 0, stream>>>(
+        case 1:
+            ssm_scan_f32<threads, 16, 1><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 2:
+            ssm_scan_f32<threads, 16, 2><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 3:
+            ssm_scan_f32<threads, 16, 3><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 4:
+            ssm_scan_f32<threads, 16, 4><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 5:
+            ssm_scan_f32<threads, 16, 5><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 6:
+            ssm_scan_f32<threads, 16, 6><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 7:
+            ssm_scan_f32<threads, 16, 7><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+        case 8:
+            ssm_scan_f32<threads, 16, 8><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
+                src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+            break;
+
+        default:
+            ssm_scan_f32<threads, 16, 0><<<blocks, threads, 0, stream>>>(
                 src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
                 src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
-        }
-        else
-        {
-            ssm_scan_single_step_f32<threads, 16><<<blocks, threads, 0, stream>>>(
-                src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb2,
-                src1_nb3, src2_nb2, src3_nb1,
-                src4_nb2, src5_nb2,
-                dst);
+            break;
         }
     }
     else
@@ -219,23 +164,24 @@ static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float
     }
 }
 
-void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // dt
-    const struct ggml_tensor * src3 = dst->src[3]; // A
-    const struct ggml_tensor * src4 = dst->src[4]; // B
-    const struct ggml_tensor * src5 = dst->src[5]; // C
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context &ctx, ggml_tensor *dst)
+{
+    const struct ggml_tensor *src0 = dst->src[0]; // s
+    const struct ggml_tensor *src1 = dst->src[1]; // x
+    const struct ggml_tensor *src2 = dst->src[2]; // dt
+    const struct ggml_tensor *src3 = dst->src[3]; // A
+    const struct ggml_tensor *src4 = dst->src[4]; // B
+    const struct ggml_tensor *src5 = dst->src[5]; // C
 
     // const int64_t d_state = src0->ne[0];
     // const int64_t d_inner = src0->ne[1];
     // const int64_t l = src1->ne[1];
     // const int64_t b = src0->ne[2];
 
-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nc = src0->ne[0];  // d_state
+    const int64_t nr = src0->ne[1];  // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -251,14 +197,14 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // required to get correct offset for state destination (i.e. src1->nb[3])
     GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
 
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    const float * src2_d = (const float *) src2->data;
-    const float * src3_d = (const float *) src3->data;
-    const float * src4_d = (const float *) src4->data;
-    const float * src5_d = (const float *) src5->data;
-    float * dst_d = (float *) dst->data;
-    cudaStream_t stream = ctx.stream();
+    const float *src0_d = (const float *)src0->data;
+    const float *src1_d = (const float *)src1->data;
+    const float *src2_d = (const float *)src2->data;
+    const float *src3_d = (const float *)src3->data;
+    const float *src4_d = (const float *)src4->data;
+    const float *src5_d = (const float *)src5->data;
+    float *dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
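Note on the change above: the sequence length is now passed both as the runtime argument `L_param` and as the template parameter `L_template` (with `0` meaning "not known at compile time"). The host-side `switch` maps small runtime values of `L` onto dedicated instantiations, so the `#pragma unroll` over the per-token loop sees a compile-time trip count and can unroll fully, while the `default` case keeps the generic kernel for longer sequences. Below is a minimal, self-contained sketch of that runtime-to-compile-time dispatch pattern; the names `scan_kernel` and `launch_scan` are illustrative only and are not part of this file.

// Toy kernel: L_template == 0 means the trip count is only known at runtime;
// any other value makes it a compile-time constant, so the loop can unroll fully.
template <int L_template>
__global__ void scan_kernel(const float *x, float *y, int L_param) {
    const int L = (L_template == 0) ? L_param : L_template;
    float acc = 0.0f;
#pragma unroll
    for (int i = 0; i < L; ++i) {
        acc += x[blockIdx.x * L + i];
    }
    y[blockIdx.x] = acc;
}

// Host-side dispatch: specialize for the small L values that matter,
// fall back to the generic instantiation for everything else.
static void launch_scan(const float *x, float *y, int B, int L, cudaStream_t stream) {
    switch (L) {
        case 1:  scan_kernel<1><<<B, 1, 0, stream>>>(x, y, L); break;
        case 2:  scan_kernel<2><<<B, 1, 0, stream>>>(x, y, L); break;
        case 4:  scan_kernel<4><<<B, 1, 0, stream>>>(x, y, L); break;
        default: scan_kernel<0><<<B, 1, 0, stream>>>(x, y, L); break;
    }
}

Each extra case is another kernel instantiation, so the trade-off is longer compile times and a larger binary in exchange for unrolled code on short sequences; the `default` path preserves correctness for any other length.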