- 
                Notifications
    You must be signed in to change notification settings 
- Fork 13.4k
Faster ssm scan #10558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
        
      
    
  
     Merged
                    Faster ssm scan #10558
Changes from 3 commits
      Commits
    
    
            Show all changes
          
          
            8 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      65180fb
              
                faster ssm_scan
              
              
                A3shTnT 6a6c954
              
                delete unused commnet
              
              
                A3shTnT 1e64567
              
                clang format
              
              
                A3shTnT 828e4f7
              
                add space
              
              
                A3shTnT e52a22d
              
                modify unnecessary calculations
              
              
                A3shTnT 0dd48a6
              
                faster ssm conv implementatioin
              
              
                A3shTnT c009e89
              
                modify file name with dash
              
              
                A3shTnT c9a07d2
              
                Merge remote-tracking branch 'origin' into faster_ssm_scan
              
              
                A3shTnT File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| #include "ssm_conv.cuh" | ||
|  | ||
| // SSM convolution step: for every (row, token, sequence) computes the dot product of a | ||
| // d_conv-wide sliding window of src0 with the matching row of the conv weights in src1. | ||
| // | ||
| // Launch layout: | ||
| //   gridDim.x  : one block column per sequence (blockIdx.x is the sequence index) | ||
| //   gridDim.y  : number of row slices; the nr rows are split evenly across blockIdx.y | ||
| //   blockDim.x : any size >= 1; threads stride over the n_t tokens | ||
| // No shared memory is used. | ||
| // | ||
| // Parameters: | ||
| //   src0 : input windows, indexed with byte strides src0_nb0 (token), src0_nb1 (row), | ||
| //          src0_nb2 (sequence); each row is read as ncs consecutive floats | ||
| //   src1 : conv weights, row byte stride src1_nb1; each row is read as nc consecutive floats | ||
| //   dst  : output, byte strides dst_nb0 (row), dst_nb1 (token), dst_nb2 (sequence); | ||
| //          rows are written as consecutive floats | ||
| //   nc   : window width, ncs : floats per src0 row, nr : number of rows, | ||
| //   n_t  : tokens per sequence, n_s : number of sequences (unused inside the kernel) | ||
| // | ||
| // The block_size template parameter is not used by the body; it is kept so existing | ||
| // instantiations keep compiling. | ||
| template <int block_size> | ||
| static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
|                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, | ||
|                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, | ||
|                                     const int nc, const int ncs, const int nr, const int n_t, const int n_s) { | ||
|     const int i3  = blockIdx.x;  // sequence index | ||
|     const int ith = blockIdx.y;  // index of this block's row slice | ||
|     const int nth = gridDim.y;   // number of row slices (taken from the launch, not hard-coded) | ||
|  | ||
|     // rows per slice | ||
|     const int dr = (nr + nth - 1) / nth; | ||
|  | ||
|     // row range for this slice | ||
|     const int ir0 = dr * ith; | ||
|     const int ir1 = min(ir0 + dr, nr); | ||
|     const int ir  = ir1 - ir0; | ||
|  | ||
|     // trailing slices can be empty when nr is not a multiple of the slice count | ||
|     if (ir <= 0) { | ||
|         return; | ||
|     } | ||
|  | ||
|     // byte offsets are computed in size_t so large tensors cannot overflow 32-bit int | ||
|     const float * c = (const float *) ((const char *) src1 + (size_t) ir0 * src1_nb1);  // {d_conv, d_inner} | ||
|  | ||
|     // block-stride loop over tokens: correct for any blockDim.x, including n_t > blockDim.x | ||
|     for (int i2 = threadIdx.x; i2 < n_t; i2 += blockDim.x) { | ||
|         // {d_conv - 1 + n_t, d_inner, n_seqs} | ||
|         // sliding window | ||
|         const float * s = (const float *) ((const char *) src0 + (size_t) ir0 * src0_nb1 + | ||
|                                            (size_t) i2 * src0_nb0 + (size_t) i3 * src0_nb2); | ||
|         float * x = (float *) ((char *) dst + (size_t) ir0 * dst_nb0 + (size_t) i2 * dst_nb1 + | ||
|                                (size_t) i3 * dst_nb2);  // {d_inner, n_t, n_s} | ||
|  | ||
|         // TODO: transpose the output for smaller strides for big batches? | ||
|         // d_inner | ||
|         for (int i1 = 0; i1 < ir; ++i1) { | ||
|             // rowwise dot product, accumulated in single precision | ||
|             float sumf = 0.0f; | ||
|  | ||
|             // d_conv | ||
| #pragma unroll | ||
|             for (int i0 = 0; i0 < nc; ++i0) { | ||
|                 sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc]; | ||
|             } | ||
|             x[i1] = sumf; | ||
|         } | ||
|     } | ||
| } | ||
|  | ||
| // Launches ssm_conv_f32 on the given stream. | ||
| // Grid: (n_s, WARP_SIZE) blocks; block: min(n_t, 1024) threads, the kernel strides over | ||
| // the remaining tokens. Returns without launching when there is nothing to compute. | ||
| static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, | ||
|                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, | ||
|                               const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, | ||
|                               const int n_s, cudaStream_t stream) { | ||
|     // a launch with a zero-sized grid or block is an invalid configuration | ||
|     if (n_t <= 0 || n_s <= 0) { | ||
|         return; | ||
|     } | ||
|  | ||
|     // a thread block cannot have more than 1024 threads; longer sequences are covered | ||
|     // by the token loop inside the kernel | ||
|     const int max_threads = 1024; | ||
|     const int n_threads   = n_t < max_threads ? n_t : max_threads; | ||
|  | ||
|     const dim3 block_dims(n_threads, 1, 1); | ||
|     const dim3 grid_dims(n_s, WARP_SIZE, 1); | ||
|  | ||
|     ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>( | ||
|         src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s); | ||
| } | ||
|  | ||
| // CUDA entry point for the SSM convolution op. | ||
| // Reads the input windows from dst->src[0] and the conv weights from dst->src[1], | ||
| // validates the layout the kernel relies on, and launches the kernel on ctx's stream. | ||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
|     const struct ggml_tensor * src0 = dst->src[0];  // conv_x | ||
|     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight | ||
|  | ||
|     // every buffer is reinterpreted as float below, so check the types first | ||
|     GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|  | ||
|     const int nc  = src1->ne[0];  // d_conv | ||
|     const int ncs = src0->ne[0];  // d_conv - 1 + n_t | ||
|     const int nr  = src0->ne[1];  // d_inner | ||
|     const int n_t = dst->ne[1];   // tokens per sequence | ||
|     const int n_s = dst->ne[2];   // number of sequences in the batch | ||
|  | ||
|     GGML_ASSERT(dst->ne[0] == nr); | ||
|     GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
|     // the kernel indexes src0 rows as ncs packed floats | ||
|     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
|     // the kernel indexes weight rows as nc packed floats | ||
|     GGML_ASSERT(src1->nb[1] == src1->ne[0] * sizeof(float)); | ||
|     // the kernel writes consecutive rows of dst as consecutive floats | ||
|     GGML_ASSERT(dst->nb[0] == sizeof(float)); | ||
|  | ||
|     const float * src0_d = (const float *) src0->data; | ||
|     const float * src1_d = (const float *) src1->data; | ||
|     float *       dst_d  = (float *) dst->data; | ||
|     cudaStream_t  stream = ctx.stream(); | ||
|  | ||
|     ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], | ||
|                       dst->nb[2], nc, ncs, nr, n_t, n_s, stream); | ||
| } | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| #include "common.cuh" | ||
|  | ||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| #include "ssm_scan.cuh" | ||
|  | ||
| // #include <cuda_runtime.h> | ||
| // static __device__ void global_to_shared(const float *src, float *dst) { | ||
| // asm volatile("cp.async."); | ||
| // } | ||
|  | ||
| // Selective-scan step of the SSM: each thread owns one row (channel) of the state and | ||
| // walks the L tokens sequentially, updating its N state values and emitting one output | ||
| // value per token. | ||
| // | ||
| // Launch layout: | ||
| //   gridDim.x  : split along B (blockIdx.x is the sequence index) | ||
| //   gridDim.y  : split along D (each block handles splitD consecutive rows) | ||
| //   blockDim.x : must equal splitD | ||
| //   dynamic shared memory : 2 * splitD * (N + 1) floats (A tile followed by the state tile) | ||
| // | ||
| // Parameters: | ||
| //   src0 : initial state s, byte strides src0_nb1 (row) and src0_nb2 (sequence) | ||
| //   src1 : x, byte strides src1_nb1 (token), src1_nb2 (sequence), src1_nb3 (total size of x) | ||
| //   src2 : dt, byte strides src2_nb1 (token), src2_nb2 (sequence) | ||
| //   src3 : A, row byte stride src3_nb1 | ||
| //   src4 : B, byte strides src4_nb1 (token), src4_nb2 (sequence) | ||
| //   src5 : C, byte strides src5_nb1 (token), src5_nb2 (sequence) | ||
| //   dst  : outputs y are written from the start using the strides of src1; the final | ||
| //          states are written at byte offset src1_nb3 using the strides of src0 | ||
| //   L    : number of tokens; D and B are not read by the body | ||
| template <size_t splitD, size_t N> | ||
| __global__ void __launch_bounds__(splitD, 2) | ||
|     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, | ||
|                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, | ||
|                  const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, | ||
|                  const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
|                  const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
|                  float * __restrict__ dst, const int D, const int L, const int B) { | ||
|     // The staging loop below tiles rows in groups of 32 lanes; a fixed constant is used | ||
|     // everywhere so that tid == wid * lanes + wtid always holds. | ||
|     constexpr int lanes   = 32; | ||
|     constexpr int n_state = (int) N; | ||
|     static_assert(N >= 1 && N <= 32 && 32 % N == 0, "N must divide 32"); | ||
|     static_assert(splitD % 32 == 0, "splitD must be a multiple of 32"); | ||
|  | ||
|     const int bidx = blockIdx.x;  // split along B | ||
|     const int bidy = blockIdx.y;  // split along D | ||
|     const int tid  = threadIdx.x; | ||
|     const int wid  = tid / lanes; | ||
|     const int wtid = tid % lanes; | ||
|  | ||
|     // shared tiles are padded by one float per row (row stride N + 1) | ||
|     extern __shared__ float smem[]; | ||
|     const int stride_sA  = N + 1; | ||
|     const int stride_ss0 = N + 1; | ||
|     float *   smem_A     = smem; | ||
|     float *   smem_s0    = smem_A + splitD * stride_sA; | ||
|  | ||
|     const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
|     const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
|     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); | ||
|     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); | ||
|     const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb2)); | ||
|     const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb2)); | ||
|     float *       y_block  = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); | ||
|     float *       s_block  = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); | ||
|  | ||
|     // strides in floats | ||
|     const int stride_s0 = src0_nb1 / sizeof(float); | ||
|     const int stride_x  = src1_nb1 / sizeof(float); | ||
|     const int stride_dt = src2_nb1 / sizeof(float); | ||
|     const int stride_A  = src3_nb1 / sizeof(float); | ||
|     const int stride_B  = src4_nb1 / sizeof(float); | ||
|     const int stride_C  = src5_nb1 / sizeof(float); | ||
|     const int stride_s  = stride_s0; | ||
|     const int stride_y  = stride_x; | ||
|  | ||
|     // Stage A and the initial state in shared memory. Each group of 32 threads copies its | ||
|     // own 32 rows; in one pass the 32 lanes cover 32 / N rows of N columns each. Rows and | ||
|     // columns are addressed explicitly, so the copy does not depend on the global row | ||
|     // stride being exactly N floats. | ||
|     // todo: bank conflicts of this access pattern have not been analysed | ||
|     constexpr int rows_per_pass = lanes / n_state; | ||
|     const int     lane_row      = wtid / n_state; | ||
|     const int     lane_col      = wtid % n_state; | ||
| #pragma unroll | ||
|     for (int r = 0; r < lanes; r += rows_per_pass) { | ||
|         const int row = wid * lanes + r + lane_row; | ||
|         smem_A[row * stride_sA + lane_col]   = A_block[row * stride_A + lane_col]; | ||
|         smem_s0[row * stride_ss0 + lane_col] = s0_block[row * stride_s0 + lane_col]; | ||
|     } | ||
|  | ||
|     // the rows staged above are consumed by other threads below | ||
|     __syncthreads(); | ||
|  | ||
|     // From here on thread tid only reads and writes row tid of the shared tiles, so no | ||
|     // further block-level barrier is needed inside the token loop. | ||
|     for (int i = 0; i < L; i++) { | ||
|         // softplus(dt), skipped for large inputs where it is numerically the identity | ||
|         float dt_soft_plus = dt_block[i * stride_dt + tid]; | ||
|         if (dt_soft_plus <= 20.0f) { | ||
|             dt_soft_plus = log1pf(expf(dt_soft_plus)); | ||
|         } | ||
|         float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; | ||
|         float sumf = 0.0f; | ||
| #pragma unroll | ||
|         for (int j = 0; j < n_state; j++) { | ||
|             float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + | ||
|                           (B_block[i * stride_B + j] * x_dt); | ||
|             sumf += state * C_block[i * stride_C + j]; | ||
|             if (i == L - 1) { | ||
|                 // last token: publish the final state to the output buffer | ||
|                 s_block[tid * stride_s + j] = state; | ||
|             } else { | ||
|                 // carry the state to the next token | ||
|                 smem_s0[tid * stride_ss0 + j] = state; | ||
|             } | ||
|         } | ||
|         y_block[i * stride_y + tid] = sumf; | ||
|     } | ||
| } | ||
|  | ||
| // Host-side launcher for ssm_scan_f32. | ||
| // Uses 128 threads per block, a (B, D / 128) grid and 2 * 128 * (N + 1) floats of dynamic | ||
| // shared memory. D must be a multiple of 128 and only N == 16 is supported; any other N aborts. | ||
| static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, | ||
|                               const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, | ||
|                               const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, | ||
|                               const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||
|                               const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, | ||
|                               float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { | ||
|     const int block_threads = 128; | ||
|  | ||
|     // todo: handle a D that is not a multiple of the block size, if that can occur | ||
|     GGML_ASSERT(D % block_threads == 0); | ||
|  | ||
|     if (N != 16) { | ||
|         GGML_ABORT("doesn't support N!=16."); | ||
|     } else { | ||
|         const int  d_blocks = (D + block_threads - 1) / block_threads; | ||
|         const dim3 grid(B, d_blocks, 1); | ||
|  | ||
|         // two padded tiles (A and state), each block_threads rows of N + 1 floats | ||
|         const int shared_bytes = (block_threads * (N + 1) * 2) * sizeof(float); | ||
|  | ||
|         ssm_scan_f32<128, 16><<<grid, block_threads, shared_bytes, stream>>>( | ||
|             src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, | ||
|             src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); | ||
|     } | ||
| } | ||
|  | ||
| // CUDA entry point for the SSM scan op. | ||
| // Collects the six source tensors of dst (s, x, dt, A, B, C), validates the types and | ||
| // memory layout the kernel relies on, and launches the scan on ctx's stream. | ||
| void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
|     const struct ggml_tensor * src0 = dst->src[0];  // s | ||
|     const struct ggml_tensor * src1 = dst->src[1];  // x | ||
|     const struct ggml_tensor * src2 = dst->src[2];  // dt | ||
|     const struct ggml_tensor * src3 = dst->src[3];  // A | ||
|     const struct ggml_tensor * src4 = dst->src[4];  // B | ||
|     const struct ggml_tensor * src5 = dst->src[5];  // C | ||
|  | ||
|     // every buffer is reinterpreted as float below, so check all types first | ||
|     GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src2->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src3->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src4->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(src5->type == GGML_TYPE_F32); | ||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|  | ||
|     const int64_t nc  = src0->ne[0];  // d_state | ||
|     const int64_t nr  = src0->ne[1];  // d_inner | ||
|     const int64_t n_t = src1->ne[1];  // number of tokens per sequence | ||
|     const int64_t n_s = src0->ne[2];  // number of sequences in the batch | ||
|  | ||
|     // dst holds the outputs y followed by the final states | ||
|     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); | ||
|     GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src1->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src2->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src3->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src4->nb[0] == sizeof(float)); | ||
|     GGML_ASSERT(src5->nb[0] == sizeof(float)); | ||
|     // required for the dot product between s and C | ||
|     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); | ||
|     // required because the kernel stages rows of A assuming they are packed back to back | ||
|     GGML_ASSERT(src3->nb[1] == src3->ne[0] * sizeof(float)); | ||
|     // required for per-sequence offsets for states | ||
|     GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); | ||
|     // required to get correct offset for state destination (i.e. src1->nb[3]) | ||
|     GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); | ||
|  | ||
|     const float * src0_d = (const float *) src0->data; | ||
|     const float * src1_d = (const float *) src1->data; | ||
|     const float * src2_d = (const float *) src2->data; | ||
|     const float * src3_d = (const float *) src3->data; | ||
|     const float * src4_d = (const float *) src4->data; | ||
|     const float * src5_d = (const float *) src5->data; | ||
|     float *       dst_d  = (float *) dst->data; | ||
|     cudaStream_t  stream = ctx.stream(); | ||
|  | ||
|     ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], | ||
|                       src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], | ||
|                       src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); | ||
| } | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| #include "common.cuh" | ||
|  | ||
| void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | 
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.