Skip to content

Commit 4b0f9d5

Browse files
committed
Refactor conv2d_implicit_kernel for improved readability and consistency; update parameter comments and remove unused code
1 parent 5ffe97b commit 4b0f9d5

File tree

1 file changed

+44
-99
lines changed

1 file changed

+44
-99
lines changed

ggml/src/ggml-cuda/conv2d-implicit.cu

Lines changed: 44 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include "convert.cuh"
33

44
typedef struct{
5-
unsigned int n; //batch szie
6-
unsigned int c; //channel number
5+
unsigned int n; //batch size
6+
unsigned int c; //number of channels
77
unsigned int h; //height
88
unsigned int w; //width
99
unsigned int k; //number of filters
@@ -23,23 +23,18 @@ typedef struct{
2323

2424
template <typename T>
2525
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
26-
const T * __restrict__ kernel,
27-
float * __restrict__ output,
28-
const param_t param) {
26+
const T * __restrict__ kernel,
27+
float * __restrict__ output,
28+
const param_t param) {
2929

30-
extern __shared__ __align__(16 * 1024) char smem[];
30+
extern __shared__ unsigned char smem[];
3131
T *smemweight = reinterpret_cast<T *>(smem);
3232
float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);
3333

3434
int tx = threadIdx.x;
3535
int bx = blockIdx.x;
3636
int by = blockIdx.y;
37-
38-
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
39-
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
40-
// // printf("param.n=%d\n",param.n);
41-
// }
42-
// __syncthreads();
37+
4338

4439
// Warp tile
4540
const int lane_id = threadIdx.x % 32;
@@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
6055
int posh_ori[4];
6156
int posw_ori[4];
6257
#pragma unroll
63-
for (int i = 0; i < 4; ++i)
64-
{
58+
for (int i = 0; i < 4; ++i){
6559
posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p;
6660
posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q;
6761
}
@@ -82,86 +76,66 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
8276
float input_frag[2][8];
8377
float output_frag[8][8];
8478
#pragma unroll
85-
for (int i = 0; i < 8; ++i)
86-
{
79+
for (int i = 0; i < 8; ++i){
8780
#pragma unroll
88-
for (int j = 0; j < 8; ++j)
89-
{
81+
for (int j = 0; j < 8; ++j){
9082
output_frag[i][j] = 0;
9183
}
9284
}
9385
// ldg
94-
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
95-
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
96-
// }
97-
// __syncthreads();
9886
#pragma unroll
99-
for (int i = 0; i < 4; ++i)
100-
{
101-
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
102-
{
87+
for (int i = 0; i < 4; ++i){
88+
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
10389
weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
10490
}
105-
else
106-
{
91+
else{
10792
weight_ldg_reg[i] = (T)0.f;
10893
}
10994
}
11095
int curC = (tx / 32) / (param.r * param.s); // channel offset
11196
int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
11297
int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
11398
#pragma unroll
114-
for (int i = 0; i < 4; ++i)
115-
{
99+
for (int i = 0; i < 4; ++i){
116100
int curH = posh_ori[i] + curR * param.d_h; // input h
117101
int curW = posw_ori[i] + curS * param.d_w; // input w
118102
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
119-
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
120-
{
103+
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
121104
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
122105
}
123-
else
124-
{
106+
else{
125107
input_ldg_reg[i] = 0.0;
126108
}
127109
}
128110
// sts
129-
for (int i = 0; i < 4; ++i)
130-
{
111+
for (int i = 0; i < 4; ++i){
131112
smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
132113
}
133-
for (int i = 0; i < 4; ++i)
134-
{
114+
for (int i = 0; i < 4; ++i){
135115
smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
136116
}
137117

138118
__syncthreads();
139119
// lds
140120
#pragma unroll
141-
for (int i = 0; i < 4; ++i)
142-
{
121+
for (int i = 0; i < 4; ++i){
143122
weight_frag[0][i] = smemweight[weight_lds_addr + i];
144123
weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
145124
}
146125
#pragma unroll
147-
for (int i = 0; i < 4; ++i)
148-
{
126+
for (int i = 0; i < 4; ++i){
149127
input_frag[0][i] = smeminput[input_lds_addr + i];
150128
input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
151129
}
152-
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8)
153-
{
130+
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){
154131
// ldg
155132
int weiOffsetTmp = crs + 8 + tx % 8;
156133
#pragma unroll
157-
for (int i = 0; i < 4; ++i)
158-
{
159-
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
160-
{
134+
for (int i = 0; i < 4; ++i){
135+
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
161136
weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
162137
}
163-
else
164-
{
138+
else{
165139
weight_ldg_reg[i] = (T)0.f;
166140
}
167141
}
@@ -170,133 +144,104 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
170144
curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
171145

172146
#pragma unroll
173-
for (int i = 0; i < 4; ++i)
174-
{
147+
for (int i = 0; i < 4; ++i){
175148
int curH = posh_ori[i] + curR * param.d_h; // input h
176149
int curW = posw_ori[i] + curS * param.d_w; // input w
177150
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
178-
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
179-
{
151+
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
180152
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
181153
}
182-
else
183-
{
154+
else{
184155
input_ldg_reg[i] = 0.f;
185156
}
186157
}
187158
int load_flag = write_flag ^ 1;
188159
#pragma unroll
189-
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs)
190-
{
160+
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){
191161
#pragma unroll
192-
for (int i = 0; i < 4; ++i)
193-
{
162+
for (int i = 0; i < 4; ++i){
194163
weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
195164
weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
196165
}
197166
#pragma unroll
198-
for (int i = 0; i < 4; ++i)
199-
{
167+
for (int i = 0; i < 4; ++i){
200168
input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
201169
input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
202170
}
203171

204172
#pragma unroll
205-
for (int i = 0; i < 8; ++i)
206-
{
173+
for (int i = 0; i < 8; ++i){
207174
#pragma unroll
208-
for (int j = 0; j < 8; ++j)
209-
{
175+
for (int j = 0; j < 8; ++j){
210176
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
211177
}
212178
}
213179
}
214180
// sts
215-
for (int i = 0; i < 4; ++i)
216-
{
181+
for (int i = 0; i < 4; ++i){
217182
smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
218183
}
219-
for (int i = 0; i < 4; ++i)
220-
{
184+
for (int i = 0; i < 4; ++i){
221185
smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
222186
}
223187
__syncthreads();
224188
write_flag ^= 1;
225189
#pragma unroll
226-
for (int i = 0; i < 4; ++i)
227-
{
190+
for (int i = 0; i < 4; ++i){
228191
weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
229192
weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
230193
}
231194
#pragma unroll
232-
for (int i = 0; i < 4; ++i)
233-
{
195+
for (int i = 0; i < 4; ++i){
234196
input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
235197
input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
236198
}
237199
#pragma unroll
238-
for (int i = 0; i < 8; ++i)
239-
{
200+
for (int i = 0; i < 8; ++i){
240201
#pragma unroll
241-
for (int j = 0; j < 8; ++j)
242-
{
202+
for (int j = 0; j < 8; ++j){
243203
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
244204
}
245205
}
246206
}
247207

248208
// reuse smem
249209
float *smemoutput = reinterpret_cast<float *>(smem);
250-
// float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
251210

252-
// bias ldg/sts
253-
// if (tx < 128)
254-
// {
255-
// smembias[tx] = param.bias[by * 128 + tx];
256-
// }
257211

258212
uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
259213
uint32_t output_lds_addr = warp_id * 512 + lane_id;
260-
// uint32_t bias_lds_addr = warp_id / 2 * 32;
261214

262215
uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
263216
uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;
264217

265218
#pragma unroll
266-
for (int i = 0; i < 2; ++i)
267-
{
219+
for (int i = 0; i < 2; ++i){
268220
#pragma unroll
269-
for (int j = 0; j < 2; ++j)
270-
{
221+
for (int j = 0; j < 2; ++j){
271222
__syncthreads();
272-
273223
#pragma unroll
274-
for (int subi = 0; subi < 4; ++subi)
275-
{
224+
for (int subi = 0; subi < 4; ++subi){
276225
#pragma unroll
277-
for (int subj = 0; subj < 4; ++subj)
278-
{
226+
for (int subj = 0; subj < 4; ++subj){
279227
// output sts
280228
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
281229
}
282230
}
283231
__syncthreads();
284232

285233
#pragma unroll
286-
for (int subk = 0; subk < 16; ++subk)
287-
{
234+
for (int subk = 0; subk < 16; ++subk){
288235
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
289236
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
290-
// output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
291237
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
292238
}
293239
}
294240
}
295241
}
296242

297243
template <typename T>
298-
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
299-
// const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
244+
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
300245
int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
301246
int blocky = (P.k + 127) / 128; // blocky number
302247
int blockz = P.n; // blockz number

0 commit comments

Comments
 (0)