Skip to content

Commit 83a3b7d

Browse files
committed
Refactor conv2d_implicit_kernel for improved bitwise operations; add test for implicit convolution
1 parent 4b0f9d5 commit 83a3b7d

File tree

3 files changed

+434
-29
lines changed

3 files changed

+434
-29
lines changed

ggml/src/ggml-cuda/conv2d-implicit.cu

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
3737

3838

3939
// Warp tile
40-
const int lane_id = threadIdx.x % 32;
41-
const int warp_id = threadIdx.x / 32;
42-
const int mma_tid_x = (lane_id / 2) % 8;
43-
const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2);
40+
const int lane_id = threadIdx.x & 31;
41+
const int warp_id = threadIdx.x >> 5;
42+
const int mma_tid_x = (lane_id >> 1) % 8;
43+
const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1);
4444
// lds addr
45-
int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
46-
int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
45+
int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4;
46+
int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4;
4747

48-
int x = bx * 128 + input_lds_addr;
49-
int y = by * 128 + weight_lds_addr;
48+
// int x = bx * 128 + input_lds_addr;
49+
// int y = by * 128 + weight_lds_addr;
5050
int z = blockIdx.z;
5151

5252
T weight_ldg_reg[4];
@@ -56,20 +56,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
5656
int posw_ori[4];
5757
#pragma unroll
5858
for (int i = 0; i < 4; ++i){
59-
posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p;
60-
posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q;
59+
posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p;
60+
posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q;
6161
}
6262

6363
int inOffset = z * param.c * param.h * param.w;
64-
int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s;
64+
int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s;
6565
int inChannelOffset = param.h * param.w;
66-
int weightChannelOffset = param.r * param.s;
66+
// int weightChannelOffset = param.r * param.s;
6767
int weightKOffset = param.c * param.r * param.s;
6868

6969
// sts addr
70-
int weight_sts_addr = (tx % 8) * 132 +
71-
(tx / 8) * 4;
72-
int input_sts_addr = (tx / 32) * 128 + (tx % 32);
70+
int weight_sts_addr = (tx & 7) * 132 +
71+
(tx >> 3) * 4;
72+
int input_sts_addr = (warp_id) * 128 + (lane_id);
7373

7474
int write_flag = 1;
7575
T weight_frag[2][8];
@@ -85,16 +85,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
8585
// ldg
8686
#pragma unroll
8787
for (int i = 0; i < 4; ++i){
88-
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
89-
weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
88+
if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){
89+
weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset];
9090
}
9191
else{
9292
weight_ldg_reg[i] = (T)0.f;
9393
}
9494
}
95-
int curC = (tx / 32) / (param.r * param.s); // channel offset
96-
int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
97-
int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
95+
int curC = (warp_id) / (param.r * param.s); // channel offset
96+
int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset
97+
int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset
9898
#pragma unroll
9999
for (int i = 0; i < 4; ++i){
100100
int curH = posh_ori[i] + curR * param.d_h; // input h
@@ -127,21 +127,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
127127
input_frag[0][i] = smeminput[input_lds_addr + i];
128128
input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
129129
}
130+
131+
// main loop
130132
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){
131133
// ldg
132-
int weiOffsetTmp = crs + 8 + tx % 8;
134+
int weiOffsetTmp = crs + 8 + (tx & 7);
133135
#pragma unroll
134136
for (int i = 0; i < 4; ++i){
135-
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
137+
if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){
136138
weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
137139
}
138140
else{
139141
weight_ldg_reg[i] = (T)0.f;
140142
}
141143
}
142-
curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset
143-
curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
144-
curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
144+
curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset
145+
curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset
146+
curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset
145147

146148
#pragma unroll
147149
for (int i = 0; i < 4; ++i){
@@ -160,13 +162,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
160162
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){
161163
#pragma unroll
162164
for (int i = 0; i < 4; ++i){
163-
weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
164-
weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
165+
weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
166+
weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
165167
}
168+
// // compute base pointer once
169+
// T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132;
170+
171+
// // first 4 values -> weight_frag[...][0..3]
172+
// float4 v0 = *reinterpret_cast<const float4*>(base_ptr);
173+
174+
// // next 4 values (offset +16) -> weight_frag[...][4..7]
175+
// float4 v1 = *reinterpret_cast<const float4*>(base_ptr + 16);
176+
177+
// // unpack into weight_frag
178+
// *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][0]) = v0;
179+
// *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][4]) = v1;
166180
#pragma unroll
167181
for (int i = 0; i < 4; ++i){
168-
input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
169-
input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
182+
input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
183+
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
170184
}
171185

172186
#pragma unroll

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS)
198198
endif()
199199
llama_build_and_test(test-gguf.cpp)
200200
llama_build_and_test(test-backend-ops.cpp)
201+
llama_build_and_test(test-conv2d-implicit.cpp)
201202

202203
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
203204
llama_build_and_test(test-autorelease.cpp LABEL "model")

0 commit comments

Comments
 (0)