@@ -163,7 +163,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
163163 const uint inKOffset = start_k + innerColA * 4 ;
164164#pragma unroll
165165 for (uint offset = 0 ; offset + rowStrideA <= BM; offset += rowStrideA) {
166- const unsigned int gemm_i = bx * BM + innerRowA + offset;
166+ const unsigned int gemm_i = bx * BM + innerRowA + offset;
167167 // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
168168 int n = (ksplit > 0 ) ? fastdiv (gemm_i, param.PQZ_fastdiv ) : z;
169169 const unsigned int npqz_res = fastmodulo (gemm_i, param.PQZ_fastdiv );
@@ -173,26 +173,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
173173 const int posw_ori = fastmodulo (ohow_res, param.OW_fastdiv ) * param.stride0 - param.padding0 ;
174174 int inOffset = n * inNOffset;
175175 if (vec_load){
176- // const uint cur0 = fastdiv(inKOffset,
177- // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
178- // const uint cur0_res = fastmodulo(inKOffset,
179- // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
180- // const uint cur1 = fastdiv(cur0_res,
181- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
182- // const uint cur1_res = fastmodulo(cur0_res,
183- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
184- // const uint cur2 = fastdiv(cur1_res,
185- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
186- // const uint cur3 = fastmodulo(cur1_res,
187- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
188- // const uint curC = layout == 0 ? cur3 : cur0;
189- // const uint curT = layout == 0 ? cur0 : cur1;
190- // const uint curR = layout == 0 ? cur1 : cur2;
191- // const uint curS = layout == 0 ? cur2 : cur3;
192176 const int4 curIdx = inputIndices<layout>(inKOffset, param);
193- // const int curD = posd_ori + curT * param.dilation2; // input w
194- // const int curH = posh_ori + curR * param.dilation1; // input h
195- // const int curW = posw_ori + curS * param.dilation0; // input w
196177 const int curD = posd_ori + curIdx.y * param.dilation2 ; // input d
197178 const int curH = posh_ori + curIdx.z * param.dilation1 ; // input h
198179 const int curW = posw_ori + curIdx.w * param.dilation0 ; // input w
@@ -214,43 +195,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
214195 } else {
215196#pragma unroll
216197 for (int i = 0 ; i < 4 ; ++i){
217- // const uint cur0 = fastdiv(inKOffset + i,
218- // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
219- // const uint cur0_res = fastmodulo(inKOffset + i,
220- // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
221- // const uint cur1 = fastdiv(cur0_res,
222- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
223- // const uint cur1_res = fastmodulo(cur0_res,
224- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
225- // const uint cur2 = fastdiv(cur1_res,
226- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
227- // const uint cur3 = fastmodulo(cur1_res,
228- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
229- // const uint curC = layout == 0 ? cur3 : cur0;
230- // const uint curT = layout == 0 ? cur0 : cur1;
231- // const uint curR = layout == 0 ? cur1 : cur2;
232- // const uint curS = layout == 0 ? cur2 : cur3;
233198 const int4 curIdx = inputIndices<layout>(inKOffset + i, param);
234- // const int curD = posd_ori + curT * param.dilation2; // input w
235- // const int curH = posh_ori + curR * param.dilation1; // input h
236- // const int curW = posw_ori + curS * param.dilation0; // input w
237199 const int curD = posd_ori + curIdx.y * param.dilation2 ; // input d
238200 const int curH = posh_ori + curIdx.z * param.dilation1 ; // input h
239201 const int curW = posw_ori + curIdx.w * param.dilation0 ; // input w
240202 const int curC = curIdx.x ;
241- // const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
242- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
243- // const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
244- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
245- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
246- // const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
247- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
248- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
249- // const uint curC = layout == 0 ? cur2 : cur0;
250- // const uint curR = layout == 0 ? cur0 : cur1;
251- // const uint curS = layout == 0 ? cur1 : cur2;
252- // const int curH = posh_ori + curR * param.d_h; // input h
253- // const int curW = posw_ori + curS * param.d_w; // input w
254203 if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){
255204 int inOffsetTmp = layout == 0 ?
256205 curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@@ -360,12 +309,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
360309 const uint inKkOffset = innerColA * 4 + crs + BK;
361310#pragma unroll
362311 for (uint offset = 0 ; offset + rowStrideA <= BM; offset += rowStrideA) {
363- // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
364- // const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
365- // const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
366- // const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
367- // int inOffset = n * param.c * param.h * param.w ;
368- const unsigned int gemm_i = bx * BM + innerRowA + offset;
312+ const unsigned int gemm_i = bx * BM + innerRowA + offset;
369313 int n = (ksplit > 0 ) ? fastdiv (gemm_i, param.PQZ_fastdiv ) : z;
370314 const unsigned int npqz_res = fastmodulo (gemm_i, param.PQZ_fastdiv );
371315 const int posd_ori = fastdiv ((ksplit > 0 ) ? npqz_res: gemm_i, param.OHOW_fastdiv ) * param.stride2 - param.padding2 ;
@@ -379,28 +323,10 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
379323 const int curH = posh_ori + curIdx.z * param.dilation1 ; // input h
380324 const int curW = posw_ori + curIdx.w * param.dilation0 ; // input w
381325 const int curC = curIdx.x ;
382- // const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
383- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
384- // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
385- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
386- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
387- // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
388- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
389- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
390- // const uint curC = layout == 0 ? cur2 : cur0;
391- // const uint curR = layout == 0 ? cur0 : cur1;
392- // const uint curS = layout == 0 ? cur1 : cur2;
393-
394- // const int curH = posh_ori + curR * param.d_h; // input h
395- // const int curW = posw_ori + curS * param.d_w; // input w
396326 if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){
397327 int inOffsetTmp = layout == 0 ?
398328 curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
399329 curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
400- // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){
401- // int inOffsetTmp = layout == 0 ?
402- // curH * inChannelOffset + curW * param.c + curC:
403- // curC * inChannelOffset + curH * param.w + curW;
404330 float4 tmp = reinterpret_cast <const float4 *>(&input[inOffset + inOffsetTmp])[0 ];
405331 smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0 ] = tmp.x ;
406332 smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y ;
@@ -414,29 +340,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
414340 } else {
415341#pragma unroll
416342 for (int i = 0 ; i < 4 ; ++i){
417- // const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
418- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
419- // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
420- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
421- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
422- // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
423- // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
424- // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
425- // const uint curC = layout == 0 ? cur2 : cur0;
426- // const uint curR = layout == 0 ? cur0 : cur1;
427- // const uint curS = layout == 0 ? cur1 : cur2;
428-
429- // const int curH = posh_ori + curR * param.d_h; // input h
430- // const int curW = posw_ori + curS * param.d_w; // input w
431343 const int4 curIdx = inputIndices<layout>(inKkOffset + i, param);
432344 const int curD = posd_ori + curIdx.y * param.dilation2 ; // input d
433345 const int curH = posh_ori + curIdx.z * param.dilation1 ; // input h
434346 const int curW = posw_ori + curIdx.w * param.dilation0 ; // input w
435347 const int curC = curIdx.x ;
436- // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
437- // int inOffsetTmp = layout == 0 ?
438- // curH * inChannelOffset + curW * param.c + curC:
439- // curC * inChannelOffset + curH * param.w + curW;
440348 if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){
441349 int inOffsetTmp = layout == 0 ?
442350 curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@@ -521,7 +429,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
521429 const int col = (ksplit > 0 ) ? fastmodulo (gemm_i, param.PQZ_fastdiv ) : gemm_i;
522430 if (n < param.n && row < param.k && col < PQZ){
523431 const uint outOffset = ksplit > 0 ?
524- // z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col :
525432 ((z * param.n + n) * param.k + row) * PQZ + col :
526433 (z * param.k + row) * PQZ + col;
527434 output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
@@ -790,7 +697,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
790697 const unsigned int K = param.c * param.r * param.s * param.t ;
791698 const uint weightKOffset = K; // param.c * param.r * param.s * param.t;
792699 const uint inChannelOffset = param.c * param.w ;
793- const uint inDepthOffset = param.h * param.c * param.w ;
700+ const uint inDepthOffset = param.h * param.c * param.w ;
794701 const uint inNOffset = param.c * param.w * param.h * param.d ;
795702
796703 // loop bounds, constexpr where possible allows for loop unrolling
@@ -854,7 +761,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
854761 if (block_k != num_block_tiles_k){
855762 const half* A_block_gmem = input;
856763 const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
857- tileMemcpyLoadA<BM, BK, NUM_THREADS, 4 >(A_block_gmem, A_gmem_cache_reg, block_k * BK,
764+ tileMemcpyLoadA<BM, BK, NUM_THREADS, 4 >(A_block_gmem, A_gmem_cache_reg, block_k * BK,
858765 inNOffset, inDepthOffset, inChannelOffset, param);
859766 tileMemcpyLoadB<BN, BK, NUM_THREADS, 4 >(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
860767 }
@@ -935,12 +842,9 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
935842 for (int j = 0 ; j < 4 ; ++j){
936843 const uint row = m_idx + subk + i * WN / 2 ;
937844 const uint gemm_i = n_idx + j*32 ;
938- // const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
939- // const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
940845 const int n = fastdiv (gemm_i, param.PQZ_fastdiv );
941846 const int col = fastmodulo (gemm_i, param.PQZ_fastdiv );
942847 if (n < param.n && row < param.k && col < PQZ){
943- // const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
944848 const uint outOffset = (n * param.k + row) * PQZ + col;
945849 uint idx = output_lds_addr + subk + j*32 *BN/2 ;
946850 idx = idx ^ ((idx & 0b1110000000 ) >> 4 );
@@ -1109,19 +1013,15 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
11091013 const uint KW = kernel->ne [0 ]; // kernel_w
11101014 const uint KH = kernel->ne [1 ]; // kernel_h
11111015 const uint KD = kernel->ne [2 ]; // kernel_d
1112- // const uint IC = input->ne[2]; // input_channels
11131016
1114- // const uint OC = kernel->ne[3]; // ouptut_chanles
1115- // const uint B = input->ne[3]; // n_batches
1116-
1117- param_t params = { B,
1118- IC,
1017+ param_t params = { B,
1018+ IC,
11191019 IH, IW, ID,
1120- OC,
1020+ OC,
11211021 KH, KW, KD,
1122- ST_Y, ST_X , ST_Z,
1123- PD_Y, PD_X , PD_Z,
1124- DL_Y, DL_X , DL_Z,
1022+ ST_X, ST_Y , ST_Z,
1023+ PD_X, PD_Y , PD_Z,
1024+ DL_X, DL_Y , DL_Z,
11251025 OH, OW, OD,
11261026 init_fastdiv_values (KW*IC),
11271027 init_fastdiv_values (OW),
0 commit comments