Skip to content

Commit 75520d6

Browse files
committed
deduplicated functions
1 parent 949e4fa commit 75520d6

File tree

1 file changed

+60
-114
lines changed

1 file changed

+60
-114
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 60 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ using namespace cub;
99

1010
#include "ssm-scan.cuh"
1111

12-
template <size_t splitD, size_t N>
12+
template <size_t splitD, size_t N, size_t L_template>
1313
__global__ void __launch_bounds__(splitD, 2)
1414
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
1515
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
@@ -18,7 +18,7 @@ __global__ void __launch_bounds__(splitD, 2)
1818
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
1919
float *__restrict__ dst, const int64_t L_param)
2020
{
21-
const size_t L = L_param;
21+
const size_t L = L_template == 0 ? L_param : L_template;
2222
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
2323
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
2424
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
@@ -62,6 +62,7 @@ __global__ void __launch_bounds__(splitD, 2)
6262
}
6363
#endif
6464

65+
#pragma unroll
6566
for (size_t i = 0; i < L; i++)
6667
{
6768
if (threadIdx.x < N)
@@ -101,89 +102,6 @@ __global__ void __launch_bounds__(splitD, 2)
101102
#endif
102103
}
103104

104-
template <size_t splitD, size_t N>
105-
__global__ void __launch_bounds__(splitD, 2)
106-
ssm_scan_single_step_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
107-
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
108-
const int src0_nb1, const int src0_nb2, const int src1_nb2,
109-
const int src1_nb3, const int src2_nb2, const int src3_nb1,
110-
const int src4_nb2, const int src5_nb2,
111-
float *__restrict__ dst)
112-
{
113-
const float *s0_block = (const float *)((const char *)src0 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
114-
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
115-
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
116-
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
117-
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb2));
118-
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb2));
119-
float *y_block = (float *)((char *)dst + (blockIdx.x * src1_nb2) + blockIdx.y * splitD * sizeof(float));
120-
float *s_block = (float *)((char *)dst + src1_nb3 + blockIdx.x * src0_nb2 + blockIdx.y * splitD * src0_nb1);
121-
122-
float regA[N];
123-
float regs0[N];
124-
125-
__shared__ float smemB[N];
126-
__shared__ float smemC[N];
127-
128-
#ifdef USE_CUB
129-
using BlockLoadA = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
130-
using BlockLoadS0 = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_VECTORIZE>;
131-
using BlockStoreS = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_VECTORIZE>;
132-
133-
__shared__ typename BlockLoadA::TempStorage block_load_tempA;
134-
__shared__ typename BlockLoadS0::TempStorage block_load_tempS0;
135-
__shared__ typename BlockStoreS::TempStorage block_store_tempS;
136-
137-
BlockLoadA(block_load_tempA).Load(A_block, regA);
138-
BlockLoadS0(block_load_tempS0).Load(s0_block, regs0);
139-
#else
140-
const int stride_s0 = src0_nb1 / sizeof(float);
141-
const int stride_A = src3_nb1 / sizeof(float);
142-
#pragma unroll
143-
for (size_t n = 0; n < N; ++n)
144-
{
145-
regA[n] = A_block[threadIdx.x * stride_A + n];
146-
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
147-
}
148-
#endif
149-
150-
if (threadIdx.x < N)
151-
{
152-
smemB[threadIdx.x] = B_block[threadIdx.x];
153-
smemC[threadIdx.x] = C_block[threadIdx.x];
154-
}
155-
__syncthreads();
156-
157-
{
158-
float dt_soft_plus = dt_block[threadIdx.x];
159-
if (dt_soft_plus <= 20.0f)
160-
{
161-
dt_soft_plus = log1pf(expf(dt_soft_plus));
162-
}
163-
float x_dt = x_block[threadIdx.x] * dt_soft_plus;
164-
float sumf = 0.0f;
165-
#pragma unroll
166-
for (size_t n = 0; n < N; n++)
167-
{
168-
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
169-
sumf += state * smemC[n];
170-
regs0[n] = state;
171-
}
172-
y_block[threadIdx.x] = sumf;
173-
}
174-
175-
#ifdef USE_CUB
176-
BlockStoreS(block_store_tempS).Store(s_block, regs0);
177-
#else
178-
const int stride_s = stride_s0;
179-
#pragma unroll
180-
for (size_t n = 0; n < N; ++n)
181-
{
182-
s_block[threadIdx.x * stride_s + n] = regs0[n];
183-
}
184-
#endif
185-
}
186-
187105
static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float *src2, const float *src3,
188106
const float *src4, const float *src5, const int src0_nb1, const int src0_nb2,
189107
const int src1_nb1, const int src1_nb2, const int src1_nb3,
@@ -198,19 +116,46 @@ static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float
198116
const dim3 blocks(B, (D + threads - 1) / threads, 1);
199117
if (N == 16)
200118
{
201-
if (L > 1)
119+
switch (L)
202120
{
203-
ssm_scan_f32<threads, 16><<<blocks, threads, 0, stream>>>(
121+
case 1:
122+
ssm_scan_f32<threads, 16, 1><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
123+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
124+
break;
125+
case 2:
126+
ssm_scan_f32<threads, 16, 2><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
127+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
128+
break;
129+
case 3:
130+
ssm_scan_f32<threads, 16, 3><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
131+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
132+
break;
133+
case 4:
134+
ssm_scan_f32<threads, 16, 4><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
135+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
136+
break;
137+
case 5:
138+
ssm_scan_f32<threads, 16, 5><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
139+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
140+
break;
141+
case 6:
142+
ssm_scan_f32<threads, 16, 6><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
143+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
144+
break;
145+
case 7:
146+
ssm_scan_f32<threads, 16, 7><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
147+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
148+
break;
149+
case 8:
150+
ssm_scan_f32<threads, 16, 8><<<blocks, threads, 0, stream>>>(src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
151+
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
152+
break;
153+
154+
default:
155+
ssm_scan_f32<threads, 16, 0><<<blocks, threads, 0, stream>>>(
204156
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb1, src1_nb2, src1_nb3,
205157
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
206-
}
207-
else
208-
{
209-
ssm_scan_single_step_f32<threads, 16><<<blocks, threads, 0, stream>>>(
210-
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb2,
211-
src1_nb3, src2_nb2, src3_nb1,
212-
src4_nb2, src5_nb2,
213-
dst);
158+
break;
214159
}
215160
}
216161
else
@@ -219,23 +164,24 @@ static void ssm_scan_f32_cuda(const float *src0, const float *src1, const float
219164
}
220165
}
221166

222-
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
223-
const struct ggml_tensor * src0 = dst->src[0]; // s
224-
const struct ggml_tensor * src1 = dst->src[1]; // x
225-
const struct ggml_tensor * src2 = dst->src[2]; // dt
226-
const struct ggml_tensor * src3 = dst->src[3]; // A
227-
const struct ggml_tensor * src4 = dst->src[4]; // B
228-
const struct ggml_tensor * src5 = dst->src[5]; // C
167+
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context &ctx, ggml_tensor *dst)
168+
{
169+
const struct ggml_tensor *src0 = dst->src[0]; // s
170+
const struct ggml_tensor *src1 = dst->src[1]; // x
171+
const struct ggml_tensor *src2 = dst->src[2]; // dt
172+
const struct ggml_tensor *src3 = dst->src[3]; // A
173+
const struct ggml_tensor *src4 = dst->src[4]; // B
174+
const struct ggml_tensor *src5 = dst->src[5]; // C
229175

230176
// const int64_t d_state = src0->ne[0];
231177
// const int64_t d_inner = src0->ne[1];
232178
// const int64_t l = src1->ne[1];
233179
// const int64_t b = src0->ne[2];
234180

235-
const int64_t nc = src0->ne[0]; // d_state
236-
const int64_t nr = src0->ne[1]; // d_inner
237-
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
238-
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
181+
const int64_t nc = src0->ne[0]; // d_state
182+
const int64_t nr = src0->ne[1]; // d_inner
183+
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
184+
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
239185

240186
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
241187
GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -251,14 +197,14 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
251197
// required to get correct offset for state destination (i.e. src1->nb[3])
252198
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
253199

254-
const float * src0_d = (const float *) src0->data;
255-
const float * src1_d = (const float *) src1->data;
256-
const float * src2_d = (const float *) src2->data;
257-
const float * src3_d = (const float *) src3->data;
258-
const float * src4_d = (const float *) src4->data;
259-
const float * src5_d = (const float *) src5->data;
260-
float * dst_d = (float *) dst->data;
261-
cudaStream_t stream = ctx.stream();
200+
const float *src0_d = (const float *)src0->data;
201+
const float *src1_d = (const float *)src1->data;
202+
const float *src2_d = (const float *)src2->data;
203+
const float *src3_d = (const float *)src3->data;
204+
const float *src4_d = (const float *)src4->data;
205+
const float *src5_d = (const float *)src5->data;
206+
float *dst_d = (float *)dst->data;
207+
cudaStream_t stream = ctx.stream();
262208

263209
GGML_ASSERT(src0->type == GGML_TYPE_F32);
264210
GGML_ASSERT(dst->type == GGML_TYPE_F32);

0 commit comments

Comments
 (0)