22#include " convert.cuh"
33
44typedef struct {
5- unsigned int n; // batch szie
6- unsigned int c; // channel number
5+ unsigned int n; // batch size
6+ unsigned int c; // number if channels
77 unsigned int h; // height
88 unsigned int w; // width
99 unsigned int k; // number of filters
@@ -23,23 +23,18 @@ typedef struct{
2323
2424template <typename T>
2525static __global__ void conv2d_implicit_kernel (const float * __restrict__ input,
26- const T * __restrict__ kernel,
27- float * __restrict__ output,
28- const param_t param) {
26+ const T * __restrict__ kernel,
27+ float * __restrict__ output,
28+ const param_t param) {
2929
30- extern __shared__ __align__ ( 16 * 1024 ) char smem[];
30+ extern __shared__ unsigned char smem[];
3131 T *smemweight = reinterpret_cast <T *>(smem);
3232 float *smeminput = reinterpret_cast <float *>(smem + 16 * 1024 );
3333
3434 int tx = threadIdx .x ;
3535 int bx = blockIdx .x ;
3636 int by = blockIdx .y ;
37-
38- // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
39- // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
40- // // printf("param.n=%d\n",param.n);
41- // }
42- // __syncthreads();
37+
4338
4439 // Warp tile
4540 const int lane_id = threadIdx .x % 32 ;
@@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
6055 int posh_ori[4 ];
6156 int posw_ori[4 ];
6257#pragma unroll
63- for (int i = 0 ; i < 4 ; ++i)
64- {
58+ for (int i = 0 ; i < 4 ; ++i){
6559 posh_ori[i] = ((bx * 128 + tx % 32 + i * 32 ) / param.Ow ) * param.u - param.p ;
6660 posw_ori[i] = ((bx * 128 + tx % 32 + i * 32 ) % param.Ow ) * param.v - param.q ;
6761 }
@@ -82,86 +76,66 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
8276 float input_frag[2 ][8 ];
8377 float output_frag[8 ][8 ];
8478#pragma unroll
85- for (int i = 0 ; i < 8 ; ++i)
86- {
79+ for (int i = 0 ; i < 8 ; ++i){
8780#pragma unroll
88- for (int j = 0 ; j < 8 ; ++j)
89- {
81+ for (int j = 0 ; j < 8 ; ++j){
9082 output_frag[i][j] = 0 ;
9183 }
9284 }
9385// ldg
94- // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
95- // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
96- // }
97- // __syncthreads();
9886#pragma unroll
99- for (int i = 0 ; i < 4 ; ++i)
100- {
101- if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k )
102- {
87+ for (int i = 0 ; i < 4 ; ++i){
88+ if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k ){
10389 weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
10490 }
105- else
106- {
91+ else {
10792 weight_ldg_reg[i] = (T)0 .f ;
10893 }
10994 }
11095 int curC = (tx / 32 ) / (param.r * param.s ); // channel offset
11196 int curR = ((tx / 32 ) % (param.r * param.s )) / param.s ; // kernel r offset
11297 int curS = ((tx / 32 ) % (param.r * param.s )) % param.s ; // kernel s offset
11398#pragma unroll
114- for (int i = 0 ; i < 4 ; ++i)
115- {
99+ for (int i = 0 ; i < 4 ; ++i){
116100 int curH = posh_ori[i] + curR * param.d_h ; // input h
117101 int curW = posw_ori[i] + curS * param.d_w ; // input w
118102 int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
119- if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c )
120- {
103+ if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c ){
121104 input_ldg_reg[i] = input[inOffset + inOffsetTmp];
122105 }
123- else
124- {
106+ else {
125107 input_ldg_reg[i] = 0.0 ;
126108 }
127109 }
128110 // sts
129- for (int i = 0 ; i < 4 ; ++i)
130- {
111+ for (int i = 0 ; i < 4 ; ++i){
131112 smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
132113 }
133- for (int i = 0 ; i < 4 ; ++i)
134- {
114+ for (int i = 0 ; i < 4 ; ++i){
135115 smeminput[input_sts_addr + i * 32 ] = input_ldg_reg[i];
136116 }
137117
138118 __syncthreads ();
139119 // lds
140120#pragma unroll
141- for (int i = 0 ; i < 4 ; ++i)
142- {
121+ for (int i = 0 ; i < 4 ; ++i){
143122 weight_frag[0 ][i] = smemweight[weight_lds_addr + i];
144123 weight_frag[0 ][i + 4 ] = smemweight[weight_lds_addr + i + 16 ];
145124 }
146125#pragma unroll
147- for (int i = 0 ; i < 4 ; ++i)
148- {
126+ for (int i = 0 ; i < 4 ; ++i){
149127 input_frag[0 ][i] = smeminput[input_lds_addr + i];
150128 input_frag[0 ][i + 4 ] = smeminput[input_lds_addr + i + 32 ];
151129 }
152- for (int crs = 0 ; crs < param.r * param.s * param.c ; crs += 8 )
153- {
130+ for (int crs = 0 ; crs < param.r * param.s * param.c ; crs += 8 ){
154131 // ldg
155132 int weiOffsetTmp = crs + 8 + tx % 8 ;
156133#pragma unroll
157- for (int i = 0 ; i < 4 ; ++i)
158- {
159- if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k )
160- {
134+ for (int i = 0 ; i < 4 ; ++i){
135+ if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k ){
161136 weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
162137 }
163- else
164- {
138+ else {
165139 weight_ldg_reg[i] = (T)0 .f ;
166140 }
167141 }
@@ -170,133 +144,104 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
170144 curS = ((crs + 8 + tx / 32 ) % (param.r * param.s )) % param.s ; // kernel s offset
171145
172146#pragma unroll
173- for (int i = 0 ; i < 4 ; ++i)
174- {
147+ for (int i = 0 ; i < 4 ; ++i){
175148 int curH = posh_ori[i] + curR * param.d_h ; // input h
176149 int curW = posw_ori[i] + curS * param.d_w ; // input w
177150 int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
178- if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c )
179- {
151+ if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c ){
180152 input_ldg_reg[i] = input[inOffset + inOffsetTmp];
181153 }
182- else
183- {
154+ else {
184155 input_ldg_reg[i] = 0 .f ;
185156 }
186157 }
187158 int load_flag = write_flag ^ 1 ;
188159#pragma unroll
189- for (int subcrs = 0 ; subcrs < 8 - 1 ; ++subcrs)
190- {
160+ for (int subcrs = 0 ; subcrs < 8 - 1 ; ++subcrs){
191161#pragma unroll
192- for (int i = 0 ; i < 4 ; ++i)
193- {
162+ for (int i = 0 ; i < 4 ; ++i){
194163 weight_frag[(subcrs + 1 ) % 2 ][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i];
195164 weight_frag[(subcrs + 1 ) % 2 ][i + 4 ] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i + 16 ];
196165 }
197166#pragma unroll
198- for (int i = 0 ; i < 4 ; ++i)
199- {
167+ for (int i = 0 ; i < 4 ; ++i){
200168 input_frag[(subcrs + 1 ) % 2 ][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i];
201169 input_frag[(subcrs + 1 ) % 2 ][i + 4 ] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i + 32 ];
202170 }
203171
204172#pragma unroll
205- for (int i = 0 ; i < 8 ; ++i)
206- {
173+ for (int i = 0 ; i < 8 ; ++i){
207174#pragma unroll
208- for (int j = 0 ; j < 8 ; ++j)
209- {
175+ for (int j = 0 ; j < 8 ; ++j){
210176 output_frag[i][j] += ggml_cuda_cast<float >(weight_frag[subcrs % 2 ][i]) * input_frag[subcrs % 2 ][j];
211177 }
212178 }
213179 }
214180 // sts
215- for (int i = 0 ; i < 4 ; ++i)
216- {
181+ for (int i = 0 ; i < 4 ; ++i){
217182 smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
218183 }
219- for (int i = 0 ; i < 4 ; ++i)
220- {
184+ for (int i = 0 ; i < 4 ; ++i){
221185 smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32 ] = input_ldg_reg[i];
222186 }
223187 __syncthreads ();
224188 write_flag ^= 1 ;
225189#pragma unroll
226- for (int i = 0 ; i < 4 ; ++i)
227- {
190+ for (int i = 0 ; i < 4 ; ++i){
228191 weight_frag[0 ][i] = smemweight[(load_flag ^ 1 ) * 132 * 8 + weight_lds_addr + i];
229192 weight_frag[0 ][i + 4 ] = smemweight[(load_flag ^ 1 ) * 132 * 8 + weight_lds_addr + i + 16 ];
230193 }
231194#pragma unroll
232- for (int i = 0 ; i < 4 ; ++i)
233- {
195+ for (int i = 0 ; i < 4 ; ++i){
234196 input_frag[0 ][i] = smeminput[(load_flag ^ 1 ) * 128 * 8 + input_lds_addr + i];
235197 input_frag[0 ][i + 4 ] = smeminput[(load_flag ^ 1 ) * 128 * 8 + input_lds_addr + i + 32 ];
236198 }
237199#pragma unroll
238- for (int i = 0 ; i < 8 ; ++i)
239- {
200+ for (int i = 0 ; i < 8 ; ++i){
240201#pragma unroll
241- for (int j = 0 ; j < 8 ; ++j)
242- {
202+ for (int j = 0 ; j < 8 ; ++j){
243203 output_frag[i][j] += ggml_cuda_cast<float >(weight_frag[1 ][i]) * input_frag[1 ][j];
244204 }
245205 }
246206 }
247207
248208 // reuse smem
249209 float *smemoutput = reinterpret_cast <float *>(smem);
250- // float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
251210
252- // bias ldg/sts
253- // if (tx < 128)
254- // {
255- // smembias[tx] = param.bias[by * 128 + tx];
256- // }
257211
258212 uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4 ;
259213 uint32_t output_lds_addr = warp_id * 512 + lane_id;
260- // uint32_t bias_lds_addr = warp_id / 2 * 32;
261214
262215 uint32_t m_idx = blockIdx .y * 128 + warp_id / 2 * 32 ;
263216 uint32_t n_idx = blockIdx .x * 128 + warp_id % 2 * 64 + lane_id;
264217
265218#pragma unroll
266- for (int i = 0 ; i < 2 ; ++i)
267- {
219+ for (int i = 0 ; i < 2 ; ++i){
268220#pragma unroll
269- for (int j = 0 ; j < 2 ; ++j)
270- {
221+ for (int j = 0 ; j < 2 ; ++j){
271222 __syncthreads ();
272-
273223#pragma unroll
274- for (int subi = 0 ; subi < 4 ; ++subi)
275- {
224+ for (int subi = 0 ; subi < 4 ; ++subi){
276225#pragma unroll
277- for (int subj = 0 ; subj < 4 ; ++subj)
278- {
226+ for (int subj = 0 ; subj < 4 ; ++subj){
279227 // output sts
280228 smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
281229 }
282230 }
283231 __syncthreads ();
284232
285233#pragma unroll
286- for (int subk = 0 ; subk < 16 ; ++subk)
287- {
234+ for (int subk = 0 ; subk < 16 ; ++subk){
288235 int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32 ;
289236 if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32 ) < param.Oh * param.Ow )
290- // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
291237 output[outOffset] = smemoutput[output_lds_addr + subk * 32 ];
292238 }
293239 }
294240 }
295241}
296242
297243template <typename T>
298- static void conv2d_implicit_cuda (const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
299- // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
244+ static void conv2d_implicit_cuda (const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
300245 int blockx = ((P.Oh * P.Ow + 127 ) / 128 ); // blockx number
301246 int blocky = (P.k + 127 ) / 128 ; // blocky number
302247 int blockz = P.n ; // blockz number
0 commit comments