Skip to content

Commit 2357922

Browse files
committed
fixed a bug now all test cases passed
1 parent 3308cce commit 2357922

File tree

3 files changed

+76
-209
lines changed

3 files changed

+76
-209
lines changed

ggml/src/ggml-cuda/conv3d-implicit.cu

Lines changed: 10 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
163163
const uint inKOffset = start_k + innerColA * 4;
164164
#pragma unroll
165165
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
166-
const unsigned int gemm_i = bx * BM + innerRowA + offset;
166+
const unsigned int gemm_i = bx * BM + innerRowA + offset;
167167
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
168168
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
169169
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
@@ -173,26 +173,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
173173
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
174174
int inOffset = n * inNOffset;
175175
if(vec_load){
176-
// const uint cur0 = fastdiv(inKOffset,
177-
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
178-
// const uint cur0_res = fastmodulo(inKOffset,
179-
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
180-
// const uint cur1 = fastdiv(cur0_res,
181-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
182-
// const uint cur1_res = fastmodulo(cur0_res,
183-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
184-
// const uint cur2 = fastdiv(cur1_res,
185-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
186-
// const uint cur3 = fastmodulo(cur1_res,
187-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
188-
// const uint curC = layout == 0 ? cur3 : cur0;
189-
// const uint curT = layout == 0 ? cur0 : cur1;
190-
// const uint curR = layout == 0 ? cur1 : cur2;
191-
// const uint curS = layout == 0 ? cur2 : cur3;
192176
const int4 curIdx = inputIndices<layout>(inKOffset, param);
193-
// const int curD = posd_ori + curT * param.dilation2; // input w
194-
// const int curH = posh_ori + curR * param.dilation1; // input h
195-
// const int curW = posw_ori + curS * param.dilation0; // input w
196177
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
197178
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
198179
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
@@ -214,43 +195,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
214195
} else {
215196
#pragma unroll
216197
for (int i = 0; i < 4; ++i){
217-
// const uint cur0 = fastdiv(inKOffset + i,
218-
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
219-
// const uint cur0_res = fastmodulo(inKOffset + i,
220-
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
221-
// const uint cur1 = fastdiv(cur0_res,
222-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
223-
// const uint cur1_res = fastmodulo(cur0_res,
224-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
225-
// const uint cur2 = fastdiv(cur1_res,
226-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
227-
// const uint cur3 = fastmodulo(cur1_res,
228-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
229-
// const uint curC = layout == 0 ? cur3 : cur0;
230-
// const uint curT = layout == 0 ? cur0 : cur1;
231-
// const uint curR = layout == 0 ? cur1 : cur2;
232-
// const uint curS = layout == 0 ? cur2 : cur3;
233198
const int4 curIdx = inputIndices<layout>(inKOffset + i, param);
234-
// const int curD = posd_ori + curT * param.dilation2; // input w
235-
// const int curH = posh_ori + curR * param.dilation1; // input h
236-
// const int curW = posw_ori + curS * param.dilation0; // input w
237199
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
238200
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
239201
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
240202
const int curC = curIdx.x;
241-
// const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
242-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
243-
// const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
244-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
245-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
246-
// const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
247-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
248-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
249-
// const uint curC = layout == 0 ? cur2 : cur0;
250-
// const uint curR = layout == 0 ? cur0 : cur1;
251-
// const uint curS = layout == 0 ? cur1 : cur2;
252-
// const int curH = posh_ori + curR * param.d_h; // input h
253-
// const int curW = posw_ori + curS * param.d_w; // input w
254203
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){
255204
int inOffsetTmp = layout == 0 ?
256205
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@@ -360,12 +309,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
360309
const uint inKkOffset = innerColA * 4 + crs + BK;
361310
#pragma unroll
362311
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
363-
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
364-
// const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
365-
// const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
366-
// const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
367-
// int inOffset = n * param.c * param.h * param.w ;
368-
const unsigned int gemm_i = bx * BM + innerRowA + offset;
312+
const unsigned int gemm_i = bx * BM + innerRowA + offset;
369313
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
370314
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
371315
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
@@ -379,28 +323,10 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
379323
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
380324
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
381325
const int curC = curIdx.x;
382-
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
383-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
384-
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
385-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
386-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
387-
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
388-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
389-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
390-
// const uint curC = layout == 0 ? cur2 : cur0;
391-
// const uint curR = layout == 0 ? cur0 : cur1;
392-
// const uint curS = layout == 0 ? cur1 : cur2;
393-
394-
// const int curH = posh_ori + curR * param.d_h; // input h
395-
// const int curW = posw_ori + curS * param.d_w; // input w
396326
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){
397327
int inOffsetTmp = layout == 0 ?
398328
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
399329
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
400-
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){
401-
// int inOffsetTmp = layout == 0 ?
402-
// curH * inChannelOffset + curW * param.c + curC:
403-
// curC * inChannelOffset + curH * param.w + curW;
404330
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
405331
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
406332
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
@@ -414,29 +340,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
414340
} else {
415341
#pragma unroll
416342
for (int i = 0; i < 4; ++i){
417-
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
418-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
419-
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
420-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
421-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
422-
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
423-
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
424-
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
425-
// const uint curC = layout == 0 ? cur2 : cur0;
426-
// const uint curR = layout == 0 ? cur0 : cur1;
427-
// const uint curS = layout == 0 ? cur1 : cur2;
428-
429-
// const int curH = posh_ori + curR * param.d_h; // input h
430-
// const int curW = posw_ori + curS * param.d_w; // input w
431343
const int4 curIdx = inputIndices<layout>(inKkOffset + i, param);
432344
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
433345
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
434346
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
435347
const int curC = curIdx.x;
436-
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
437-
// int inOffsetTmp = layout == 0 ?
438-
// curH * inChannelOffset + curW * param.c + curC:
439-
// curC * inChannelOffset + curH * param.w + curW;
440348
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){
441349
int inOffsetTmp = layout == 0 ?
442350
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@@ -521,7 +429,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
521429
const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i;
522430
if (n < param.n && row < param.k && col < PQZ){
523431
const uint outOffset = ksplit > 0 ?
524-
// z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col :
525432
((z * param.n + n) * param.k + row) * PQZ + col :
526433
(z * param.k + row) * PQZ + col;
527434
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
@@ -790,7 +697,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
790697
const unsigned int K = param.c * param.r * param.s * param.t;
791698
const uint weightKOffset = K; //param.c * param.r * param.s * param.t;
792699
const uint inChannelOffset = param.c * param.w;
793-
const uint inDepthOffset = param.h * param.c * param.w;
700+
const uint inDepthOffset = param.h * param.c * param.w;
794701
const uint inNOffset = param.c * param.w * param.h * param.d;
795702

796703
// loop bounds, constexpr where possible allows for loop unrolling
@@ -854,7 +761,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
854761
if (block_k != num_block_tiles_k){
855762
const half* A_block_gmem = input;
856763
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
857-
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
764+
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
858765
inNOffset, inDepthOffset, inChannelOffset, param);
859766
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
860767
}
@@ -935,12 +842,9 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
935842
for (int j = 0; j < 4; ++j){
936843
const uint row = m_idx + subk + i * WN / 2;
937844
const uint gemm_i = n_idx + j*32;
938-
// const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
939-
// const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
940845
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
941846
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
942847
if(n < param.n && row < param.k && col < PQZ){
943-
// const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
944848
const uint outOffset = (n * param.k + row) * PQZ + col;
945849
uint idx = output_lds_addr + subk + j*32*BN/2;
946850
idx = idx ^ ((idx & 0b1110000000) >> 4);
@@ -1109,19 +1013,15 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
11091013
const uint KW = kernel->ne[0]; // kernel_w
11101014
const uint KH = kernel->ne[1]; // kernel_h
11111015
const uint KD = kernel->ne[2]; // kernel_h
1112-
// const uint IC = input->ne[2]; // input_channels
11131016

1114-
// const uint OC = kernel->ne[3]; // ouptut_chanles
1115-
// const uint B = input->ne[3]; // n_batches
1116-
1117-
param_t params = { B,
1118-
IC,
1017+
param_t params = { B,
1018+
IC,
11191019
IH, IW, ID,
1120-
OC,
1020+
OC,
11211021
KH, KW, KD,
1122-
ST_Y, ST_X, ST_Z,
1123-
PD_Y, PD_X, PD_Z,
1124-
DL_Y, DL_X, DL_Z,
1022+
ST_X, ST_Y, ST_Z,
1023+
PD_X, PD_Y, PD_Z,
1024+
DL_X, DL_Y, DL_Z,
11251025
OH, OW, OD,
11261026
init_fastdiv_values(KW*IC),
11271027
init_fastdiv_values(OW),

0 commit comments

Comments
 (0)