@@ -9,19 +9,9 @@ __global__ void __launch_bounds__(splitD, 2)
99 const int src2_nb1, const int src2_nb2, const int src3_nb1,
1010 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
1111 const int64_t s_off, const int64_t d_inner, const int64_t L) {
12-
13- constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
1412 const int bidx = blockIdx .x ; // split along B (sequences)
1513 const int bidy = blockIdx .y ; // split along D (d_inner)
1614 const int tid = threadIdx .x ;
17- const int wid = tid / 32 ;
18- const int wtid = tid % 32 ;
19-
20- extern __shared__ float smem[];
21- const int stride_sA = N + 1 ;
22- const int stride_ss0 = N + 1 ;
23- float * smem_A = smem;
24- float * smem_s0 = smem_A + splitD * stride_sA;
2515
2616 const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
2717 const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof (float ));
@@ -41,46 +31,35 @@ __global__ void __launch_bounds__(splitD, 2)
4131 const int stride_s = stride_s0;
4232 const int stride_y = d_inner;
4333
44- // can N not be 16? for example 32?
45- if (N == 16 ) {
46- #pragma unroll
47- for (size_t i = 0 ; i < splitD / 4 ; i += 2 ) {
48- float value = A_block[(wid * warp_size + i) * stride_A + wtid];
49- // todo: bank conflict
50- // I am always confused with how to use the swizzling method to solve
51- // bank conflit. Hoping somebody can tell me.
52- smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
53- }
34+ float A[N];
35+ float s[N];
36+
5437#pragma unroll
55- for (size_t i = 0 ; i < splitD / 4 ; i += 2 ) {
56- float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
57- smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
58- }
38+ for (int j = 0 ; j < N; j++) {
39+ A[j] = A_block[tid * N + j];
40+ s[j] = s0_block[tid * stride_s0 + j];
5941 }
6042
61- __syncthreads ();
62-
6343 for (int64_t i = 0 ; i < L; i++) {
6444 float dt_soft_plus = dt_block[i * stride_dt + tid];
65- if (dt_soft_plus <= 20 .0f ) {
66- dt_soft_plus = log1pf (exp (dt_soft_plus));
67- }
68- float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
45+ dt_soft_plus = (dt_soft_plus > 20 .0f ) ? dt_soft_plus : log1pf (expf (dt_soft_plus));
46+
47+ const float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
6948 float sumf = 0 .0f ;
49+
7050#pragma unroll
7151 for (size_t j = 0 ; j < N; j++) {
72- float state = (smem_s0[tid * stride_ss0 + j] * expf (dt_soft_plus * smem_A[tid * stride_sA + j])) +
73- (B_block[i * stride_B + j] * x_dt);
74- sumf += state * C_block[i * stride_C + j];
75- if (i == L - 1 ) {
76- s_block[tid * stride_s + j] = state;
77- } else {
78- smem_s0[tid * stride_ss0 + j] = state;
79- }
52+ const float exp_term = expf (dt_soft_plus * A[j]);
53+ s[j] = fmaf (s[j], exp_term, B_block[i * stride_B + j] * x_dt);
54+ sumf = fmaf (s[j], C_block[i * stride_C + j], sumf);
8055 }
81- __syncthreads ();
8256 y_block[i * stride_y + tid] = sumf;
8357 }
58+
59+ #pragma unroll
60+ for (int j = 0 ; j < N; j++) {
61+ s_block[tid * stride_s + j] = s[j];
62+ }
8463}
8564
8665// assumes as many threads as d_state
0 commit comments