Skip to content

Commit 50dfd09

Browse files
committed
CUDA/HIP: ssm-scan: switch from shared memory to reisters, fixes indexing problem on warp64 devices
1 parent fd1234c commit 50dfd09

File tree

1 file changed

+18
-39
lines changed

1 file changed

+18
-39
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,9 @@ __global__ void __launch_bounds__(splitD, 2)
99
const int src2_nb1, const int src2_nb2, const int src3_nb1,
1010
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
1111
const int64_t s_off, const int64_t d_inner, const int64_t L) {
12-
13-
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1412
const int bidx = blockIdx.x; // split along B (sequences)
1513
const int bidy = blockIdx.y; // split along D (d_inner)
1614
const int tid = threadIdx.x;
17-
const int wid = tid / 32;
18-
const int wtid = tid % 32;
19-
20-
extern __shared__ float smem[];
21-
const int stride_sA = N + 1;
22-
const int stride_ss0 = N + 1;
23-
float * smem_A = smem;
24-
float * smem_s0 = smem_A + splitD * stride_sA;
2515

2616
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
2717
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
@@ -41,46 +31,35 @@ __global__ void __launch_bounds__(splitD, 2)
4131
const int stride_s = stride_s0;
4232
const int stride_y = d_inner;
4333

44-
// can N not be 16? for example 32?
45-
if (N == 16) {
46-
#pragma unroll
47-
for (size_t i = 0; i < splitD / 4; i += 2) {
48-
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
49-
// todo: bank conflict
50-
// I am always confused with how to use the swizzling method to solve
51-
// bank conflit. Hoping somebody can tell me.
52-
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
53-
}
34+
float A[N];
35+
float s[N];
36+
5437
#pragma unroll
55-
for (size_t i = 0; i < splitD / 4; i += 2) {
56-
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
57-
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
58-
}
38+
for (int j = 0; j < N; j++) {
39+
A[j] = A_block[tid * N + j];
40+
s[j] = s0_block[tid * stride_s0 + j];
5941
}
6042

61-
__syncthreads();
62-
6343
for (int64_t i = 0; i < L; i++) {
6444
float dt_soft_plus = dt_block[i * stride_dt + tid];
65-
if (dt_soft_plus <= 20.0f) {
66-
dt_soft_plus = log1pf(exp(dt_soft_plus));
67-
}
68-
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
45+
dt_soft_plus = (dt_soft_plus > 20.0f) ? dt_soft_plus : log1pf(expf(dt_soft_plus));
46+
47+
const float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
6948
float sumf = 0.0f;
49+
7050
#pragma unroll
7151
for (size_t j = 0; j < N; j++) {
72-
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
73-
(B_block[i * stride_B + j] * x_dt);
74-
sumf += state * C_block[i * stride_C + j];
75-
if (i == L - 1) {
76-
s_block[tid * stride_s + j] = state;
77-
} else {
78-
smem_s0[tid * stride_ss0 + j] = state;
79-
}
52+
const float exp_term = expf(dt_soft_plus * A[j]);
53+
s[j] = fmaf(s[j], exp_term, B_block[i * stride_B + j] * x_dt);
54+
sumf = fmaf(s[j], C_block[i * stride_C + j], sumf);
8055
}
81-
__syncthreads();
8256
y_block[i * stride_y + tid] = sumf;
8357
}
58+
59+
#pragma unroll
60+
for (int j = 0; j < N; j++) {
61+
s_block[tid * stride_s + j] = s[j];
62+
}
8463
}
8564

8665
// assumes as many threads as d_state

0 commit comments

Comments
 (0)