File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
1010 float * __restrict__ dst, const int64_t L) {
1111 GGML_UNUSED (src1_nb0);
1212 GGML_UNUSED (src2_nb0);
13+
14+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
1315 const int bidx = blockIdx .x ; // split along B
1416 const int bidy = blockIdx .y ; // split along D
1517 const int tid = threadIdx .x ;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
4446 if (N == 16 ) {
4547#pragma unroll
4648 for (size_t i = 0 ; i < splitD / 4 ; i += 2 ) {
47- float value = A_block[(wid * warpSize + i) * stride_A + wtid];
49+ float value = A_block[(wid * warp_size + i) * stride_A + wtid];
4850 // todo: bank conflict
4951 // I am always confused with how to use the swizzling method to solve
5052 // bank conflit. Hoping somebody can tell me.
51- smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
53+ smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
5254 }
5355#pragma unroll
5456 for (size_t i = 0 ; i < splitD / 4 ; i += 2 ) {
55- float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56- smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
57+ float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
58+ smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16 ) > 0 ? 1 : 0 )] = value;
5759 }
5860 }
5961
You can’t perform that action at this time.
0 commit comments