diff --git a/gpu/bitnet_kernels/bitnet_kernels.h b/gpu/bitnet_kernels/bitnet_kernels.h index 1d897908..365ad28f 100644 --- a/gpu/bitnet_kernels/bitnet_kernels.h +++ b/gpu/bitnet_kernels/bitnet_kernels.h @@ -49,14 +49,14 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric constexpr int wmma_K = 32; constexpr int wmma_N = 16; int in_thread_C_local[1]; - signed char A_local[K_per_loop]; + alignas(16) signed char A_local[K_per_loop]; int B_reshape_local[1]; signed char B_decode_local[K_per_loop]; int red_buf0[1]; in_thread_C_local[0] = 0; #pragma unroll for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) { - *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop))); + *(int4*)(A_local + 0) = *(int4*)(A + blockIdx.y * K + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop))); B_reshape_local[0] = *(int*)(B + (((int)blockIdx.x) * N_block_size * K / 4) + (k_0 * K_block_size * K_per_loop * wmma_N / 4) + @@ -76,8 +76,8 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric for (int offset = K_block_size/2; offset > 0; offset /= 2) { red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size); } - int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y)); + int out_idx = ( blockIdx.y * K + (((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y)); int ws_idx = out_idx / (N / ws_num); if (threadIdx.x == 0) dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]); -} \ No newline at end of file +}