@@ -37,16 +37,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
3737
3838
3939 // Warp tile
40- const int lane_id = threadIdx .x % 32 ;
41- const int warp_id = threadIdx .x / 32 ;
42- const int mma_tid_x = (lane_id / 2 ) % 8 ;
43- const int mma_tid_y = (lane_id / 16 ) * 2 + (lane_id % 2 );
40+ const int lane_id = threadIdx .x & 31 ;
41+ const int warp_id = threadIdx .x >> 5 ;
42+ const int mma_tid_x = (lane_id >> 1 ) % 8 ;
43+ const int mma_tid_y = (lane_id >> 4 ) * 2 + (lane_id & 1 );
4444 // lds addr
45- int weight_lds_addr = (warp_id / 2 ) * 32 + mma_tid_y * 4 ;
46- int input_lds_addr = (warp_id % 2 ) * 64 + mma_tid_x * 4 ;
45+ int weight_lds_addr = (warp_id >> 1 ) * 32 + mma_tid_y * 4 ;
46+ int input_lds_addr = (warp_id & 1 ) * 64 + mma_tid_x * 4 ;
4747
48- int x = bx * 128 + input_lds_addr;
49- int y = by * 128 + weight_lds_addr;
48+ // int x = bx * 128 + input_lds_addr;
49+ // int y = by * 128 + weight_lds_addr;
5050 int z = blockIdx .z ;
5151
5252 T weight_ldg_reg[4 ];
@@ -56,20 +56,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
5656 int posw_ori[4 ];
5757#pragma unroll
5858 for (int i = 0 ; i < 4 ; ++i){
59- posh_ori[i] = ((bx * 128 + tx % 32 + i * 32 ) / param.Ow ) * param.u - param.p ;
60- posw_ori[i] = ((bx * 128 + tx % 32 + i * 32 ) % param.Ow ) * param.v - param.q ;
59+ posh_ori[i] = ((bx * 128 + lane_id + i * 32 ) / param.Ow ) * param.u - param.p ;
60+ posw_ori[i] = ((bx * 128 + lane_id + i * 32 ) % param.Ow ) * param.v - param.q ;
6161 }
6262
6363 int inOffset = z * param.c * param.h * param.w ;
64- int weiOffset = (by * 128 + tx / 8 * 4 ) * param.c * param.r * param.s ;
64+ int weiOffset = (by * 128 + ( tx >> 3 ) * 4 ) * param.c * param.r * param.s ;
6565 int inChannelOffset = param.h * param.w ;
66- int weightChannelOffset = param.r * param.s ;
66+ // int weightChannelOffset = param.r * param.s;
6767 int weightKOffset = param.c * param.r * param.s ;
6868
6969 // sts addr
70- int weight_sts_addr = (tx % 8 ) * 132 +
71- (tx / 8 ) * 4 ;
72- int input_sts_addr = (tx / 32 ) * 128 + (tx % 32 );
70+ int weight_sts_addr = (tx & 7 ) * 132 +
71+ (tx >> 3 ) * 4 ;
72+ int input_sts_addr = (warp_id ) * 128 + (lane_id );
7373
7474 int write_flag = 1 ;
7575 T weight_frag[2 ][8 ];
@@ -85,16 +85,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
8585// ldg
8686#pragma unroll
8787 for (int i = 0 ; i < 4 ; ++i){
88- if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k ){
89- weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
88+ if (tx % 8 < weightKOffset && by * 128 + ( tx >> 3 ) * 4 + i < param.k ){
89+ weight_ldg_reg[i] = kernel[weiOffset + ( tx & 7 ) + i * weightKOffset];
9090 }
9191 else {
9292 weight_ldg_reg[i] = (T)0 .f ;
9393 }
9494 }
95- int curC = (tx / 32 ) / (param.r * param.s ); // channel offset
96- int curR = ((tx / 32 ) % (param.r * param.s )) / param.s ; // kernel r offset
97- int curS = ((tx / 32 ) % (param.r * param.s )) % param.s ; // kernel s offset
95+ int curC = (warp_id ) / (param.r * param.s ); // channel offset
96+ int curR = ((warp_id ) % (param.r * param.s )) / param.s ; // kernel r offset
97+ int curS = ((warp_id ) % (param.r * param.s )) % param.s ; // kernel s offset
9898#pragma unroll
9999 for (int i = 0 ; i < 4 ; ++i){
100100 int curH = posh_ori[i] + curR * param.d_h ; // input h
@@ -127,21 +127,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
127127 input_frag[0 ][i] = smeminput[input_lds_addr + i];
128128 input_frag[0 ][i + 4 ] = smeminput[input_lds_addr + i + 32 ];
129129 }
130+
131+ // main loop
130132 for (int crs = 0 ; crs < param.r * param.s * param.c ; crs += 8 ){
131133 // ldg
132- int weiOffsetTmp = crs + 8 + tx % 8 ;
134+ int weiOffsetTmp = crs + 8 + ( tx & 7 ) ;
133135#pragma unroll
134136 for (int i = 0 ; i < 4 ; ++i){
135- if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k ){
137+ if (weiOffsetTmp < weightKOffset && by * 128 + ( tx >> 3 ) * 4 + i < param.k ){
136138 weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
137139 }
138140 else {
139141 weight_ldg_reg[i] = (T)0 .f ;
140142 }
141143 }
142- curC = (crs + 8 + tx / 32 ) / (param.r * param.s ); // channel offset
143- curR = ((crs + 8 + tx / 32 ) % (param.r * param.s )) / param.s ; // kernel r offset
144- curS = ((crs + 8 + tx / 32 ) % (param.r * param.s )) % param.s ; // kernel s offset
144+ curC = (crs + 8 + warp_id ) / (param.r * param.s ); // channel offset
145+ curR = ((crs + 8 + warp_id ) % (param.r * param.s )) / param.s ; // kernel r offset
146+ curS = ((crs + 8 + warp_id ) % (param.r * param.s )) % param.s ; // kernel s offset
145147
146148#pragma unroll
147149 for (int i = 0 ; i < 4 ; ++i){
@@ -160,13 +162,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
160162 for (int subcrs = 0 ; subcrs < 8 - 1 ; ++subcrs){
161163#pragma unroll
162164 for (int i = 0 ; i < 4 ; ++i){
163- weight_frag[(subcrs + 1 ) % 2 ][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i];
164- weight_frag[(subcrs + 1 ) % 2 ][i + 4 ] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i + 16 ];
165+ weight_frag[(subcrs + 1 ) & 1 ][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i];
166+ weight_frag[(subcrs + 1 ) & 1 ][i + 4 ] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1 ) * 132 + i + 16 ];
165167 }
168+ // // compute base pointer once
169+ // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132;
170+
171+ // // first 4 values -> weight_frag[...][0..3]
172+ // float4 v0 = *reinterpret_cast<const float4*>(base_ptr);
173+
174+ // // next 4 values (offset +16) -> weight_frag[...][4..7]
175+ // float4 v1 = *reinterpret_cast<const float4*>(base_ptr + 16);
176+
177+ // // unpack into weight_frag
178+ // *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][0]) = v0;
179+ // *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][4]) = v1;
166180#pragma unroll
167181 for (int i = 0 ; i < 4 ; ++i){
168- input_frag[(subcrs + 1 ) % 2 ][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i];
169- input_frag[(subcrs + 1 ) % 2 ][i + 4 ] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i + 32 ];
182+ input_frag[(subcrs + 1 ) & 1 ][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i];
183+ input_frag[(subcrs + 1 ) & 1 ][i + 4 ] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1 ) * 128 + i + 32 ];
170184 }
171185
172186#pragma unroll
0 commit comments