From ab15f6cd5fede5726c48319108955b58a4a8ffe5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 1 Nov 2025 20:08:15 -0400 Subject: [PATCH 01/19] use conv2d_implicit as template; add conv3d parameters --- ggml/src/ggml-cuda/conv3d-implicit.cu | 1062 ++++++++++++++++++++++++ ggml/src/ggml-cuda/conv3d-implicit.cuh | 353 ++++++++ 2 files changed, 1415 insertions(+) create mode 100644 ggml/src/ggml-cuda/conv3d-implicit.cu create mode 100644 ggml/src/ggml-cuda/conv3d-implicit.cuh diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu new file mode 100644 index 0000000000000..640366b80ae5b --- /dev/null +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -0,0 +1,1062 @@ +// #include +#include "ggml.h" +#include "common.cuh" +#include "convert.cuh" +#include "conv3d-implicit.cuh" + + +typedef unsigned int uint; +constexpr uint WARPSIZE = 32; +#define CUDA_NCHW_2_NHWC_TILE_DIM 32 +#define CUDA_NCHW_2_NHWC_BLOCK_NM 8 +#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 + + +//currently not use; in future for split-k kernels +// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { +// const int row = blockIdx.x; +// const int col = threadIdx.x; + +// float sum = 0.0f; +// if (row * blockDim.x + col < ncols) { +// for (int i = 0; i < nrows; ++i){ +// sum += x[i * ncols + row * blockDim.x + col]; +// } +// dst[row * blockDim.x + col] = sum; +// } +// } + +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + int x = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; + int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + + __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM]; + + for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ + + const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i; + if(imat >= nmat) + break; + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + } + } + __syncthreads(); + + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(ty + j < ne01 && tx < ne00){ + const int col = (threadIdx.y+j) ^ threadIdx.x; + dst[imat*n + (ty+j)*ne00 + tx] = ggml_cuda_cast(tile[threadIdx.x][col]); + } + } + } +} + +template +static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const param_t param) { + + __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ? 
+ sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; + T *smemweight = reinterpret_cast(smem); + float *smeminput = reinterpret_cast(smem + 2 * BK * (BN+PAD) * sizeof(T)); + + const uint tx = threadIdx.x; + const uint bx = blockIdx.x; + const uint by = blockIdx.y; + + const uint PQ = param.Oh * param.Ow; + + // Warp tile + const uint lane_id = tx % WARPSIZE; + const uint warp_id = tx / WARPSIZE; + const int mma_tid_x = warp_id / (BN / WN); + const int mma_tid_y = warp_id % (BN / WN); + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4 + + int z = blockIdx.z; + + int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; + int weightKOffset = param.c * param.r * param.s; + + const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const uint start_k = (ksplit > 0)? z * ks: 0; + const uint end_k = min(start_k + ks, weightKOffset); + + int write_flag = 1; + T weight_frag[2][WNITER * TN]; + float input_frag[2][WMITER * TM] = {0.f}; + float output_frag[WMITER * TM * WNITER * TN] = {0.f}; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = tx / (BK / 4); + const uint innerColA = tx % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + +// ldg + const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + smemweight[weight_sts_addr + offset + 0] = tmp.x; + smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + }else{ // read 4 halves + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + const half *val = reinterpret_cast(&tmp); + smemweight[weight_sts_addr + offset + 0] = val[0]; + smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i){ + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + }else{ +#pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; + } else { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } + + + const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? 
(bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint cur0 = fastdiv(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i){ + const uint cur0 = fastdiv(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ + int inOffsetTmp = layout == 0 ? 
+ curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; + } else { + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } + } + } + __syncthreads(); + + // lds + const uint input_lds_addr = mma_tid_x * WM; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + + const uint weight_lds_addr = mma_tid_y * WN; +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + + for (int crs = start_k; crs < end_k; crs += BK) { + + int load_flag = write_flag ^ 1; +#pragma unroll + for (int subcrs = 0; subcrs < BK - 1; ++subcrs) + { + +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK + + (subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK + + (subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + + // execute warptile matmul +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]); + } + } + } + } + } + // ldg +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + } else { + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + const half *val = reinterpret_cast(&tmp); + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { +#pragma unroll + for (int i = 0; i < 
4; ++i) + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + }else{ +#pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ + // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i]; + } else { + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint cur0 = fastdiv(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; + + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i){ + const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? 
cur1 : cur2; + + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; + } else { + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } + } + } + __syncthreads(); + + write_flag ^= 1; + +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * (BM+PAD) * BK + + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK + + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[1][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[1][wSubColIdx * TN + resIdxN]); + } + } + } + } + } + + // reuse smem + float *smemoutput = reinterpret_cast(smem); + + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE + + threadColInWarp * TN * WSUBM + threadRowInWarp * TM; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = bx * BM + mma_tid_x * WM; + +#pragma unroll + for (int i = 0; i < WMITER; ++i) + { +#pragma unroll + for (int j = 0; j < WNITER; ++j) + { + __syncthreads(); + +#pragma unroll + for (int subi = 0; subi < TM; ++subi) + { +#pragma unroll + for (int subj = 0; subj < TN; ++subj) + { + // output sts + smemoutput[output_sts_addr + subj * WSUBM + subi] = + output_frag[(i * TM + subi) * (WNITER * TN) + j * TN + subj]; + } + } + __syncthreads(); +#pragma unroll + for (int subk = 0; subk < TM * TN; ++subk){ + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = (ksplit > 0) ? gemm_i / PQ : z; + const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + const uint outOffset = ksplit > 0 ? 
+ z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } + } + } + } +} + + + +template +__device__ __forceinline__ void ldmatrix_a( + const half* src, + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] +){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + + uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); + uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0][0]), "=r"(reg_[0][0][1]), "=r"(reg_[1][0][0]), "=r"(reg_[1][0][1]) + : "r"(src_addr) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0][0]), "=r"(reg_[2][0][1]), "=r"(reg_[3][0][0]), "=r"(reg_[3][0][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][0][0]), "=r"(reg_[4][0][1]), "=r"(reg_[5][0][0]), "=r"(reg_[5][0][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][0][0]), "=r"(reg_[6][0][1]), "=r"(reg_[7][0][0]), "=r"(reg_[7][0][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b10000; + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + : "r"(src_addr) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b110000; + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + : "r"(src_addr) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), 
"=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + src_addr ^= 0b10000; + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ void ldmatrix_b( + const half* src, + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); + + uint32_t (®_) [4][8] = reinterpret_cast(reg); + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); + uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + src_addr ^= 0b10000; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + src_addr ^= 0b110000; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + src_addr ^= 0b10000; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) + : 
"r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif +} + +template +static __global__ void conv3d_implicit_kernel(const half * __restrict__ input, + const half * __restrict__ kernel, + half * __restrict__ output, + const param_t param) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +constexpr unsigned int MMA_M = 16; +constexpr unsigned int MMA_N = 8; + + + const unsigned int K = param.c * param.r * param.s; + const uint inChannelOffset = param.c * param.w; + const uint weightKOffset = param.c * param.r * param.s; + + // loop bounds, constexpr where possible allows for loop unrolling + constexpr unsigned int mma_tiles_per_warp_k = 4; + constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; + constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; + const unsigned int num_block_tiles_k = (K + (BK-1)) / BK; + + // calculate block/warp indices + const unsigned int block_m = blockIdx.y; + const unsigned int block_n = blockIdx.x; + const unsigned int warp_m = threadIdx.y; + const unsigned int warp_n = threadIdx.x / 32; + + // double buffering + extern __shared__ half shmem[]; + half* A_block_smem = shmem; + half* B_block_smem = &shmem[BM * BK]; + constexpr int BUFFER_SIZE = BM * BK + BK * BN; + + // declare register storage + // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together + uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; + + // convenience cast to half for register storage + half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); + + // accumulators start at 0 + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ + acc_register_[mma_m][mma_n][0] = 0; + acc_register_[mma_m][mma_n][1] = 0; + acc_register_[mma_m][mma_n][2] = 0; + acc_register_[mma_m][mma_n][3] = 0; + } + } + + static_assert(BM == 256); + static_assert(BN == 256); + static_assert(BK == 32); + static_assert(NUM_THREADS == 256); + float4 A_gmem_cache_reg[4]; + float4 B_gmem_cache_reg[4]; + + // prefetch the first block tile of A,B into shared memory + + const half* A_block_gmem = input; + const half* B_block_gmem = kernel + block_n * BN * weightKOffset; + tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); + + int offset_direction = 1; + + for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ + __syncthreads(); + + if (block_k != num_block_tiles_k){ + const half* A_block_gmem = input; + const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); + } + half* A_warp_tile = A_block_smem + (warp_m * WM * BK); + half* B_warp_tile = B_block_smem + (warp_n * WN * 
BK);
+
+        ldmatrix_a(A_warp_tile, A_register_);
+        ldmatrix_b(B_warp_tile, B_register_);
+
+        // outer product between mma tiles
+#pragma unroll
+        for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){
+#pragma unroll
+            for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
+#pragma unroll
+                for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
+                    asm volatile (
+                        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+                        "{%0, %1}, "
+                        "{%2, %3}, "
+                        "{%4}, "
+                        "{%5, %6};"
+                        : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
+                        : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
+                          "r"(B_register[mma_k][mma_n]),
+                          "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
+                    );
+                }
+            }
+        }
+
+
+        if (block_k != num_block_tiles_k)
+        {
+            // switch smem buffers each iteration
+            A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction;
+            B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
+            offset_direction = -1 * offset_direction;
+
+            tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem);
+            tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem);
+        }
+    }
+
+    // reuse smem
+    half *smemoutput = shmem;
+    const uint lane_id = threadIdx.x % WARPSIZE;
+    const uint mma_row = lane_id / 4;
+    const uint mma_col = lane_id % 4;
+    const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
+    const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
+    const uint m_idx = block_n * BN + warp_n * WN;
+    const uint n_idx = block_m * BM + warp_m * WM + lane_id;
+
+#pragma unroll
+    for (int i = 0; i < 2; ++i)
+    {
+        __syncthreads();
+
+        for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+        {
+            for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
+            {
+                uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
+                uint idx = output_sts_addr +
+                           mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
+                // swizzle the store index to avoid shared-memory bank conflicts; the same XOR is applied on the loads below
+                idx = idx ^ ((idx & 0b1110000000) >> 4);
+                uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&smemoutput[idx]);
+                dst_ptr[0] = reg_[0];
+                dst_ptr = reinterpret_cast<uint32_t *>(&smemoutput[idx + 8 * BN / 2]);
+                dst_ptr[0] = reg_[1];
+            }
+        }
+        __syncthreads();
+
+#pragma unroll
+        for (int subk = 0; subk < WN / 2; ++subk){
+            for (int j = 0; j < 4; ++j){
+                const uint row = m_idx + subk + i * WN / 2;
+                const uint gemm_i = n_idx + j*32;
+                const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
+                const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
+                if(n < param.n && row < param.k && col < param.Oh * param.Ow){
+                    const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                    uint idx = output_lds_addr + subk + j*32*BN/2;
+                    idx = idx ^ ((idx & 0b1110000000) >> 4);
+                    output[outOffset] = smemoutput[idx];
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED(input);
+    GGML_UNUSED(kernel);
+    GGML_UNUSED(output);
+    GGML_UNUSED(param);
+    NO_DEVICE_CODE;
+#endif
+}
+
+
+#define NUM_VARIANTS 4
+
+/*
+   conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256]
+   conv_shapes[][1]: ne_input=[96,128,512,1],ne_kernel=[3,3,512,512]
+   conv_shapes[][2]: ne_input=[192,256,512,1],ne_kernel=[3,3,512,512]
+*/
+constexpr static int conv_shapes[][NUM_VARIANTS] = {
+    { 128, 128, 128, 256 }, // BM
+    { 256, 128, 256, 128 }, // BN
+    {   8,   8,   8,   8 }, // BK
+    { 128,  64,  32, 128 }, // WM
+    {  32,  32, 256,  32 }, // WN
+    {   2,   2,   1,   1 }, // WNITER
+    {   8,   4,   4,   4 }, // TM
+    {   8,   4,   8,   8 }, // TN
+    { 256, 256, 128, 256 }  // NUM_THREADS
+};
+
+template
+static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
+
+    const uint BM = conv_shapes[0][CONV_SHAPE];
+    const uint BN = conv_shapes[1][CONV_SHAPE];
+    const uint BK = conv_shapes[2][CONV_SHAPE];
+    const uint WM = conv_shapes[3][CONV_SHAPE];
+    const uint WN = conv_shapes[4][CONV_SHAPE];
+    const uint WNITER = conv_shapes[5][CONV_SHAPE];
+    const uint TM = conv_shapes[6][CONV_SHAPE];
+    const uint TN = conv_shapes[7][CONV_SHAPE];
+    const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE];
+    int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number
+    int blocky = (P.k + BN-1) / BN;             // blocky number
+    int blockz = P.n;                           // blockz number
+    int thready = 1;                            // thready number per block
+    int threadz = 1;                            // threadz number per block
+    dim3 thblock(NUM_THREADS, thready, threadz);
+    dim3 grid(blockx, blocky, blockz);
+
+    conv3d_implicit_kernel<<>>(X_D, K_D, Y_D, P);
+}
+
+static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
+
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
+
+        int id = ggml_cuda_get_device();
+
+        int64_t ne = P.c * P.h * P.w * P.n;
+        int64_t ne00 = P.c;
+        int64_t ne01 = P.h * P.w;
+        ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
+
+        dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM);
+        dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM, CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
+        NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01);
+
+        ne = P.c * P.r * P.s * P.k;
+        ne01 = P.r * P.s;
+        ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
+        dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM);
+        NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+
+        const half *X_H = input_f16.get();
+        const half *K_H = kernel_f16.get();
+        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
+
+        constexpr unsigned int BM_dim = 256;
+        constexpr unsigned int BN_dim = 256;
+        constexpr unsigned int BK_dim = 32;
+
+        constexpr unsigned int WARPS_PER_BLOCK_M = 2;
+        constexpr unsigned int WARPS_PER_BLOCK_N = 4;
+        constexpr unsigned int WARPS_PER_BLOCK_K = 4;
+
+        constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
+        constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
+        constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
+        const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
+        const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
+        constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
+        constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
+        constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
+        const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
+
+        cudaFuncSetAttribute(conv3d_implicit_kernel,
+                             cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is the maximum for sm_75
+        dim3 gridDim(BlocksN, BlocksM);
+        dim3 blockDim(ThreadsN, ThreadsM);
+
+        conv3d_implicit_kernel
+            <<>>(X_H, K_H, Y_H.get(), P);
+        const to_fp32_cuda_t to_fp32_cuda =
            ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+    } else {
+        conv3d_implicit_cuda(X_D, K_D, Y_D, P, st);
+    }
+
+}
+
+static void conv3d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
+    conv3d_implicit_cuda(X_D, K_D, Y_D, P, st);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(cc);
+}
+
+void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input = dst->src[1];
+    float * K_D = (float *) kernel->data;
+    const float * X_D = (const float *) input->data;
+    float * Y_D = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+
+    cudaStream_t st = ctx.stream();
+    const int cc = ggml_cuda_info().devices[ctx.device].cc;
+
+    const int32_t * p = (const int32_t *) dst->op_params;
+    const uint ST_X = p[0];  // stride_x
+    const uint ST_Y = p[1];  // stride_y
+    const uint ST_Z = p[2];  // stride_z
+    const uint PD_X = p[3];  // padding_x
+    const uint PD_Y = p[4];  // padding_y
+    const uint PD_Z = p[5];  // padding_z
+    const uint DL_X = p[6];  // dilation_x
+    const uint DL_Y = p[7];  // dilation_y
+    const uint DL_Z = p[8];  // dilation_z
+    const uint IC   = p[9];  // number of channels
+    const uint B    = p[10]; // batch number
+    const uint OC   = p[11]; // output channels
+
+    GGML_ASSERT(p[12] == false);
+
+    const uint IW = input->ne[0];  // input_w
+    const uint IH = input->ne[1];  // input_h
+    const uint ID = input->ne[2];  // input_d
+    const uint OW = dst->ne[0];    // output_w
+    const uint OH = dst->ne[1];    // output_h
+    const uint OD = dst->ne[2];    // output_d
+    const uint KW = kernel->ne[0]; // kernel_w
+    const uint KH = kernel->ne[1]; // kernel_h
+    const uint KD = kernel->ne[2]; // kernel_d
+    // const uint IC = input->ne[2];  // input_channels
+
+    // const uint OC = kernel->ne[3]; // output_channels
+    // const uint B  = input->ne[3];  // n_batches
+
+    param_t params = { B,
+                       IC,
+                       IH, IW, ID,
+                       OC,
+                       KH, KW, KD,
+                       ST_Y, ST_X, ST_Z,
+                       PD_Y, PD_X, PD_Z,
+                       DL_Y, DL_X, DL_Z,
+                       OH, OW, OD,
+                       init_fastdiv_values(KW*IC),
+                       init_fastdiv_values(OW),
+                       init_fastdiv_values(IC),
+                       init_fastdiv_values(KW*KH),
+                       init_fastdiv_values(KW),
+                       init_fastdiv_values(OW*OH)};
+
+    if (kernel->type == GGML_TYPE_F16) {
+        conv3d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
+    } else {
+        conv3d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st);
+    }
+}
diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh
new file mode 100644
index 0000000000000..d550ec07c8e10
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh
@@ -0,0 +1,353 @@
+#pragma once
+#include "common.cuh"
+
+typedef struct{
+    unsigned int n;         //batch size
+    unsigned int c;         //number of channels
+    unsigned int h;         //height
+    unsigned int w;         //width
+    unsigned int d;         //depth
+    unsigned int k;         //number of filters
+    unsigned int r;         //filter height
+    unsigned int s;         //filter width
+    unsigned int t;         //filter depth
+    unsigned int stride0;   //stride width
+    unsigned int stride1;   //stride height
+    unsigned int stride2;   //stride depth
+    unsigned int padding0;  //padding width
+    unsigned int padding1;  //padding height
+    unsigned int padding2;  //padding depth
+    unsigned int dilation0; //dilation width
+    unsigned int dilation1; //dilation height
+    unsigned int dilation2; //dilation depth
+    unsigned int Oh;        //output height
+
unsigned int Ow; //output width + unsigned int Od; //output depth + uint3 SC_fastdiv; + uint3 OW_fastdiv; + uint3 C_fastdiv; + uint3 RS_fastdiv; + uint3 S_fastdiv; + uint3 OHOW_fastdiv; +} param_t; + + +// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel +template +__device__ __forceinline__ void tileMemcpySwizzleB( + const half* src, + half* dst, + const unsigned int src_stride, + param_t param +){ +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; + + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + // apply swizzle to the dst index + const unsigned int src_index = thread_row * src_stride + thread_col * 8; + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; + }else{ // read 4 halves + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(src_stride); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ void tileMemcpySwizzleA( + const half* src, + half* dst, + // const unsigned int src_stride, + const unsigned int inChannelOffset, + param_t param +) +{ +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; + + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column 
in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + unsigned int inOffset = n * param.c * param.h * param.w; + const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + int curH = posh_ori + curR * param.d_h; // input h + int curW = posw_ori + curS * param.d_w; // input w + // apply swizzle to the dst index + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + curR < param.r && curS < param.s && curC < param.c){ + const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + } else{ + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(inChannelOffset); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ void tileMemcpyLoadA( + const half* src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + // const unsigned int src_stride, + const unsigned int block_k, + const unsigned int inChannelOffset, + param_t param +){ +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + 
unsigned int inOffset = n * param.c * param.h * param.w; + const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + int curH = posh_ori + curR * param.d_h; // input h + int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + curR < param.r && curS < param.s && curC < param.c){ + const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + } else{ + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(inChannelOffset); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +template +__device__ __forceinline__ void tileMemcpyLoadB( + const half* src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int block_k, + const unsigned int src_stride, + param_t param +){ +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + dst_reg[i] = reinterpret_cast(&src[src_index])[0]; + }else{ // read 4 halves + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(src_stride); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +// same as above but without the swizzle + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ void tileMemcpySwizzleStore( + const float4 (&src_reg)[ELEMENTS_PER_THREAD], + half* dst +) +{ +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; + + // reinterpret 
input/output as float4 + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // apply swizzle to the dst index + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + dst_float4[dst_index] = src_reg[i]; + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src_reg); + GGML_UNUSED(dst); + NO_DEVICE_CODE; +#endif +} + +__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { + uint32_t address; + asm("{\n\t" + " .reg .u64 u64addr;\n\t" + " cvta.to.shared.u64 u64addr, %1;\n\t" + " cvt.u32.u64 %0, u64addr;\n\t" + "}" + : "=r"(address) + : "l"(pointer)); + return address; +} + + +#define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256 +void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 52455b8a6d25cea91c48c99cfa3831f30c7843eb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 1 Nov 2025 22:01:00 -0400 Subject: [PATCH 02/19] WIP: updating indices for input and kernel; enable OP_CONV_3D for cuda backend --- ggml/src/ggml-cuda/conv3d-implicit.cu | 28 +++++++++++++++------------ ggml/src/ggml-cuda/ggml-cuda.cu | 5 +++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index 640366b80ae5b..c6aa7e0749930 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -80,7 +80,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint bx = blockIdx.x; const uint by = blockIdx.y; - const uint PQ = param.Oh * param.Ow; + const uint PQZ = param.Oh * param.Ow * param.Oz; // Warp tile const uint lane_id = tx % WARPSIZE; @@ -100,7 +100,8 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, int z = blockIdx.z; int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; - int weightKOffset = param.c * param.r * param.s; + int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w; + int weightKOffset = param.c * param.r * param.s * param.t; const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; const uint start_k = (ksplit > 0)? z * ks: 0; @@ -159,11 +160,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? 
(bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; + const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ; + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + int inOffset = n * param.c * param.h * param.w * param.d; if(vec_load){ const uint cur0 = fastdiv(start_k + innerColA * 4, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset @@ -176,12 +179,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint curC = layout == 0 ? cur2 : cur0; const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ + const int curD = posd_ori + curT * param.dilation2; // input w + const int curH = posh_ori + curR * param.dilation1; // input h + const int curW = posw_ori + curS * param.dilation0; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){ int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[input_sts_addr + offset + 0] = tmp.x; smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 61a8f1df87de1..7f16a3b82637a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" +#include "ggml-cuda/conv3d-implicit.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2629,6 +2630,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_2D: ggml_cuda_op_conv2d(ctx, dst); break; + case GGML_OP_CONV_3D: + ggml_cuda_op_conv3d_implicit(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -4041,6 +4045,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: From 0a64ea8ff84add8bafc7464fb0018d223ae80e79 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 10:34:03 -0500 Subject: [PATCH 03/19] WIP: build ok --- ggml/src/ggml-cuda/conv3d-implicit.cu | 379 +++++++++++++++---------- ggml/src/ggml-cuda/conv3d-implicit.cuh | 21 +- 2 files changed, 248 insertions(+), 152 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index c6aa7e0749930..00aaa568af974 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -62,6 +62,29 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } +template +__device__ int4 inputIndices(const uint kidx, param_t param) { + + const uint cur0 = fastdiv(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const uint cur0_res = fastmodulo(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const uint cur1 = fastdiv(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const uint cur1_res = fastmodulo(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const uint cur2 = fastdiv(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur3 = fastmodulo(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur3 : cur0; + const uint curT = layout == 0 ? cur0 : cur1; + const uint curR = layout == 0 ? cur1 : cur2; + const uint curS = layout == 0 ? cur2 : cur3; + return make_int4(curC, curT, curR, curS); + +} + template 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; const uint start_k = (ksplit > 0)? z * ks: 0; @@ -158,31 +182,44 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; + const uint inKOffset = start_k + innerColA * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? 
(bx * BM + innerRowA + offset) / PQZ : z; - const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ; - const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2; - const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv); + const unsigned int gemm_i = bx * BM + innerRowA + offset; + // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; + int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv); const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; - int inOffset = n * param.c * param.h * param.w * param.d; + int inOffset = n * inNOffset; if(vec_load){ - const uint cur0 = fastdiv(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curD = posd_ori + curT * param.dilation2; // input w - const int curH = posh_ori + curR * param.dilation1; // input h - const int curW = posw_ori + curS * param.dilation0; // input w - if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){ + // const uint cur0 = fastdiv(inKOffset, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur0_res = fastmodulo(inKOffset, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur1 = fastdiv(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur1_res = fastmodulo(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur2 = fastdiv(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur3 = fastmodulo(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur3 : cur0; + // const uint curT = layout == 0 ? cur0 : cur1; + // const uint curR = layout == 0 ? cur1 : cur2; + // const uint curS = layout == 0 ? 
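// Scalar reference for inputIndices<0> (layout 0, NDHWC: the reduction index is flattened
// as ((t*R + r)*S + s)*C + c). This helper is hypothetical and only documents the values
// the fastdiv-based version is expected to produce:
//
//   __device__ int4 input_indices_ref(unsigned int kidx, int C, int R, int S) {
//       const int c = kidx % C;  kidx /= C;   // channel  (fastest varying)
//       const int s = kidx % S;  kidx /= S;   // kernel w
//       const int r = kidx % R;  kidx /= R;   // kernel h
//       const int t = (int) kidx;             // kernel d  (slowest varying)
//       return make_int4(c, t, r, s);         // same (C, T, R, S) packing as inputIndices
//   }
//
// For layout == 1 (NCDHW) the roles swap: c becomes the slowest dimension and s the fastest,
// which is why the helper switches to the TRS/RS/S fastdiv constants in that case.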
cur2 : cur3; + const int4 curIdx = inputIndices(inKOffset, param); + // const int curD = posd_ori + curT * param.dilation2; // input w + // const int curH = posh_ori + curR * param.dilation1; // input h + // const int curW = posw_ori + curS * param.dilation0; // input w + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset < end_k){ int inOffsetTmp = layout == 0 ? curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; @@ -199,23 +236,47 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ + // const uint cur0 = fastdiv(inKOffset + i, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur0_res = fastmodulo(inKOffset + i, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur1 = fastdiv(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur1_res = fastmodulo(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur2 = fastdiv(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur3 = fastmodulo(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur3 : cur0; + // const uint curT = layout == 0 ? cur0 : cur1; + // const uint curR = layout == 0 ? cur1 : cur2; + // const uint curS = layout == 0 ? cur2 : cur3; + const int4 curIdx = inputIndices(inKOffset + i, param); + // const int curD = posd_ori + curT * param.dilation2; // input w + // const int curH = posh_ori + curR * param.dilation1; // input h + // const int curW = posw_ori + curS * param.dilation0; // input w + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // const uint cur0 = fastdiv(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? 
param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; @@ -242,6 +303,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; + // main block k loop for (int crs = start_k; crs < end_k; crs += BK) { int load_flag = write_flag ^ 1; @@ -317,33 +379,50 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } } } + const uint inKkOffset = innerColA * 4 + crs + BK; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; + // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + // const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + // const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + // const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + // int inOffset = n * param.c * param.h * param.w ; + const unsigned int gemm_i = bx * BM + innerRowA + offset; + int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + int inOffset = n * inNOffset; if(vec_load){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? 
param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + const int4 curIdx = inputIndices(inKkOffset, param); + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // const uint cur0 = fastdiv(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; + + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){ + // int inOffsetTmp = layout == 0 ? + // curH * inChannelOffset + curW * param.c + curC: + // curC * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; @@ -357,24 +436,33 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? 
cur1 : cur2; - - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + // const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; + + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + const int4 curIdx = inputIndices(inKkOffset + i, param); + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + // int inOffsetTmp = layout == 0 ? + // curH * inChannelOffset + curW * param.c + curC: + // curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; @@ -448,15 +536,16 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, __syncthreads(); #pragma unroll for (int subk = 0; subk < TM * TN; ++subk){ + // output: [N*OC, OD, OH, OW] const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; - const int n = (ksplit > 0) ? gemm_i / PQ : z; - const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; - if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + const int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i; + if (n < param.n && row < param.k && col < PQZ){ const uint outOffset = ksplit > 0 ? 
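// Output indexing sketch: the result is addressed as [N, OC, OD*OH*OW] with PQZ = OD*OH*OW.
//   ksplit == 0 : blockIdx.z is the batch, so  outOffset = (z*OC + oc)*PQZ + spatial
//   ksplit  > 0 : blockIdx.z is a K-split slice; each slice writes its partial sums into
//                 its own [N, OC, PQZ] block, outOffset = ((z*N + n)*OC + oc)*PQZ + spatial,
//                 with the sum over z intended to be formed in a separate reduction pass.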
- z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + - row * param.Oh * param.Ow + col : - z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + // z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col : + ((z * param.n + n) * param.k + row) * PQZ + col : + (z * param.k + row) * PQZ + col; output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; } } @@ -464,7 +553,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } } - +#if 0 template __device__ __forceinline__ void ldmatrix_a( @@ -885,6 +974,7 @@ constexpr unsigned int MMA_N = 8; #endif } +#endif #define NUM_VARIANTS 4 @@ -925,70 +1015,70 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { - - int id = ggml_cuda_get_device(); - - int64_t ne = P.c * P.h * P.w * P.n; - int64_t ne00 = P.c; - int64_t ne01 = P.h * P.w; - ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); - - dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); - NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); - - ne = P.c * P.r * P.s * P.k; - ne01 = P.r * P.s; - ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); - dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - - const half *X_H = input_f16.get(); - const half *K_H = kernel_f16.get(); - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - - constexpr unsigned int BM_dim = 256; - constexpr unsigned int BN_dim = 256; - constexpr unsigned int BK_dim = 32; - - constexpr unsigned int WARPS_PER_BLOCK_M = 2; - constexpr unsigned int WARPS_PER_BLOCK_N = 4; - constexpr unsigned int WARPS_PER_BLOCK_K = 4; - - constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; - constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; - constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; - const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; - const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; - constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; - constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; - constexpr unsigned int NumThreads = ThreadsM * ThreadsN; - const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - - cudaFuncSetAttribute(conv3d_implicit_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - dim3 gridDim(BlocksN, BlocksM); - dim3 blockDim(ThreadsN, ThreadsM); - - conv3d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * 
P.Ow * P.n, st); - } else{ + // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { + + // int id = ggml_cuda_get_device(); + + // int64_t ne = P.c * P.h * P.w * P.n; + // int64_t ne00 = P.c; + // int64_t ne01 = P.h * P.w; + // ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); + + // dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + // dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); + // NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); + + // ne = P.c * P.r * P.s * P.k; + // ne01 = P.r * P.s; + // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + // dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + + // const half *X_H = input_f16.get(); + // const half *K_H = kernel_f16.get(); + // ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + + // constexpr unsigned int BM_dim = 256; + // constexpr unsigned int BN_dim = 256; + // constexpr unsigned int BK_dim = 32; + + // constexpr unsigned int WARPS_PER_BLOCK_M = 2; + // constexpr unsigned int WARPS_PER_BLOCK_N = 4; + // constexpr unsigned int WARPS_PER_BLOCK_K = 4; + + // constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; + // constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; + // constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + // const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; + // const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; + // constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; + // constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; + // constexpr unsigned int NumThreads = ThreadsM * ThreadsN; + // const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + + // cudaFuncSetAttribute(conv3d_implicit_kernel, + // cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + // dim3 gridDim(BlocksN, BlocksM); + // dim3 blockDim(ThreadsN, ThreadsM); + + // conv3d_implicit_kernel + // <<>>(X_H, K_H, Y_H.get(), P); + // const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + // to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + // } else{ conv3d_implicit_cuda(X_D, K_D, Y_D, P, st); - } + // } } @@ -1056,7 +1146,10 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * init_fastdiv_values(IC), init_fastdiv_values(KW*KH), init_fastdiv_values(KW), - init_fastdiv_values(OW*OH)}; + init_fastdiv_values(OW*OH), + init_fastdiv_values(OW*OH*OD), + init_fastdiv_values(KW*KH*IC), + init_fastdiv_values(KW*KH*KD)}; if (kernel->type == GGML_TYPE_F16) { conv3d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index d550ec07c8e10..4e14c15cd2235 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -29,9 +29,13 @@ typedef struct{ uint3 RS_fastdiv; uint3 S_fastdiv; uint3 OHOW_fastdiv; + uint3 PQZ_fastdiv; + uint3 RSC_fastdiv; + uint3 TRS_fastdiv; } 
param_t; + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template @@ -131,14 +135,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; unsigned int inOffset = n * param.c * param.h * param.w; const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + int curH = posh_ori + curR * param.dilation1; // input h + int curW = posw_ori + curS * param.dilation0; // input w // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); @@ -197,14 +201,14 @@ __device__ __forceinline__ void tileMemcpyLoadA( unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; unsigned int inOffset = n * param.c * param.h * param.w; const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + int curH = posh_ori + curR * param.dilation1; // input h + int curW = posw_ori + curS * param.dilation0; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; @@ -348,6 +352,5 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } - #define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From e802036eb59306143b2a8afb84731e80976e2e5f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 11:16:45 -0500 Subject: [PATCH 04/19] conv3d WIP: added a test case --- tests/CMakeLists.txt | 1 + tests/test-conv3d.cpp | 429 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 430 insertions(+) create mode 100644 tests/test-conv3d.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 
d9cc5e933f4ce..d1e5d12f2c248 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -202,6 +202,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) +llama_build_and_test(test-conv3d.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp new file mode 100644 index 0000000000000..b29a039b8f581 --- /dev/null +++ b/tests/test-conv3d.cpp @@ -0,0 +1,429 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int kw = 3, int kh = 3, int kd = 3, bool use_gpu = false ) { + // create data + int KW = kw, KH = kh, KD = kd; + int IC = ic, OC = oc; + int IW = iw, IH = ih, ID = id, N = 1; + srand(time(NULL)); + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * KD * IC * OC); + for (int i = 0; i < KW * KH * KD * IC * OC; i++) { + // adata[i] = 2.f; + // adata[i] = (float)(i%KW)-1.f; + // adata[i] = (rand() % 255) / 255.0; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + adata[i] = r; + } + + // Convert adata to fp16 format + std::vector hadata(KW * KH * KD * IC * OC); + ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * KD * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * ID * IC * N); + for (int i = 0; i < IW * IH * ID * IC * N; i++) { + // bdata[i] = (float)(i%IW)/10.f; + // bdata[i] = 1.5f; + // bdata[i] = (rand() % 255) / 255.0; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + bdata[i] = r; + } + + size_t buffer_size = 0; + { + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + buffer_size += IW * IH * ID * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = 
ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC); + // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, ID, IC*N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, + const int64_t i0, const int64_t i1, const int64_t i2); + +struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int s2 = 1; + int p0 = 1; + int p1 = 1; + int p2 = 1; + int d0 = 1; + int d1 = 1; + int d2 = 1; + + + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_3d(ctx0, model.a, model.b, ic, s0, s1, s2, p0, p1, p2, d0, d1, d2); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + 
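// Expected output extents for both test graphs, using the usual convolution arithmetic
// (a reference note, not code used by the test):
//   OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1
//   OH = (IH + 2*p1 - d1*(KH - 1) - 1) / s1 + 1
//   OD = (ID + 2*p2 - d2*(KD - 1) - 1) / s2 + 1
// With stride 1, padding 1, dilation 1 and a 3x3x3 kernel (the values set below),
// the spatial extents are preserved: OW = IW, OH = IH, OD = ID.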
struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int s2 = 1; + int p0 = 1; + int p1 = 1; + int p2 = 1; + int d0 = 1; + int d1 = 1; + int d2 = 1; + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + struct ggml_tensor* wino_res = ggml_conv_3d_direct(ctx0, model.a, model.b, + s0, s1, s2, p0, p1, p2, d0, d1, d2, + ic, n, oc); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + + + + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, + const int64_t ic, const int64_t n, const int64_t oc, double *t) { + + struct ggml_cgraph * gf = build_graph(model, ic, n, oc); + + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter data(ggml_nelements(res)); + ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); + + *t = time_us/1000; + return data; + +} + + +int main(void) +{ + ggml_time_init(); + std::vector> configs = { + // std::make_tuple(64,64,48,64,3,3), + // std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(320,1280,26,38,8,3,3,3), + // std::make_tuple(1280,1280,26,38,8,3,3,3), + // std::make_tuple(320,1280,52,76,8,3,3,3), + // std::make_tuple(1280,1280,52,76,8,3,3,3), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(128,3,768,1024,3,3), + // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(512,256,384,512,1,1), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + // std::make_tuple(320,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(320,256,1024,1920) + }; + + int k = 0; + + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true); + + ggml_gallocr_t allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph_0(model, std::get<0>(c), 0, 0); + + // compute the required memory + 
ggml_gallocr_reserve(allocr, gf); + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 20; + + double run_time0; + std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, + std::get<0>(c), 1, std::get<1>(c), &run_time0); + + ggml_gallocr_free(allocr); + + allocr = NULL; + + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + gf = build_graph_1(model, std::get<0>(c), 1, std::get<1>(c)); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_1 = NULL; + + double run_time1; + // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, + std::get<0>(c), 1, std::get<1>(c), &run_time1); + + if(k==0) { + k = 1; + fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + } + + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), + run_time0, mem_size0/1024.0f/1024.0f, + run_time1, mem_size1/1024.0f/1024.0f); + + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 26*38; i++) { + for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + + } + + // printf("\nPerforming test:\n"); + return 0; +} From a5b68bcea77496cd6ab8cdf0ce25e6778176474d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 12:33:19 -0500 Subject: [PATCH 05/19] conv3D WIP: fixed a launch param bug, results now correct; performace 3x slower than im2col --- ggml/src/ggml-cuda/conv3d-implicit.cu | 2 +- tests/test-conv3d.cpp | 42 ++++++++++++++------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index 00aaa568af974..d935eb22ec000 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -1007,7 +1007,7 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const uint TM = conv_shapes[6][CONV_SHAPE]; const uint TN = conv_shapes[7][CONV_SHAPE]; const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE]; - int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number + int blockx = ((P.Od * P.Oh * P.Ow + BM - 1) / BM); // blockx number int blocky = (P.k + BN-1) / BN; // blocky number int blockz = P.n; // blockz number int thready = 1; // thready number per block diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index b29a039b8f581..53e37efd3178e 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -241,7 
+241,7 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co ic, n, oc); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; + // int64_t *ne = wino_res->ne; // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); ggml_free(ctx0); return gf; @@ -323,9 +323,13 @@ int main(void) // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,8,3,3,3), - // std::make_tuple(1280,1280,26,38,8,3,3,3), - // std::make_tuple(320,1280,52,76,8,3,3,3), - // std::make_tuple(1280,1280,52,76,8,3,3,3), + std::make_tuple(1280,1280,26,38,8,3,3,3), + std::make_tuple(320,1280,52,76,8,3,3,3), + std::make_tuple(1280,1280,52,76,8,3,3,3), + std::make_tuple(320,1280,104,152,8,3,3,3), + std::make_tuple(1280,1280,104,152,8,3,3,3), + std::make_tuple(320,1280,208,304,4,3,3,3), + std::make_tuple(640,1280,208,304,4,3,3,3), // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), // std::make_tuple(128,3,768,1024,3,3), @@ -393,29 +397,27 @@ int main(void) if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| (IC, OC, IW, IH, ID, KW, KH, KD) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", - std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), + std::get<6>(c), std::get<7>(c), run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f); - // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 3f5c5045da890cc2a0259d96da1cafa47f560f3c Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 15:15:49 -0500 Subject: [PATCH 06/19] conv3d WIP: turn on tensor cores; NCDHW2NDHWC to be worked out --- ggml/src/ggml-cuda/conv3d-implicit.cu | 166 +++++++++++-------------- ggml/src/ggml-cuda/conv3d-implicit.cuh | 150 ++++++++++++++++------ tests/test-conv3d.cpp | 16 +-- 3 files changed, 192 insertions(+), 140 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index d935eb22ec000..4f01dfea8db40 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -62,28 +62,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } -template -__device__ int4 inputIndices(const uint kidx, param_t param) { - - const uint cur0 = fastdiv(kidx, - layout == 0 ? 
param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - const uint cur0_res = fastmodulo(kidx, - layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - const uint cur1 = fastdiv(cur0_res, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - const uint cur1_res = fastmodulo(cur0_res, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - const uint cur2 = fastdiv(cur1_res, - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur3 = fastmodulo(cur1_res, - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur3 : cur0; - const uint curT = layout == 0 ? cur0 : cur1; - const uint curR = layout == 0 ? cur1 : cur2; - const uint curS = layout == 0 ? cur2 : cur3; - return make_int4(curC, curT, curR, curS); - -} template __device__ __forceinline__ void ldmatrix_a( @@ -805,13 +782,16 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input, const param_t param) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -constexpr unsigned int MMA_M = 16; -constexpr unsigned int MMA_N = 8; + constexpr unsigned int MMA_M = 16; + constexpr unsigned int MMA_N = 8; + const uint PQZ = param.Oh * param.Ow * param.Od; - const unsigned int K = param.c * param.r * param.s; + const unsigned int K = param.c * param.r * param.s * param.t; + const uint weightKOffset = K; //param.c * param.r * param.s * param.t; const uint inChannelOffset = param.c * param.w; - const uint weightKOffset = param.c * param.r * param.s; + const uint inDepthOffset = param.h * param.c * param.w; + const uint inNOffset = param.c * param.w * param.h * param.d; // loop bounds, constexpr where possible allows for loop unrolling constexpr unsigned int mma_tiles_per_warp_k = 4; @@ -863,7 +843,7 @@ constexpr unsigned int MMA_N = 8; const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); int offset_direction = 1; @@ -874,7 +854,8 @@ constexpr unsigned int MMA_N = 8; if (block_k != num_block_tiles_k){ const half* A_block_gmem = input; const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, + inNOffset, inDepthOffset, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); } half* A_warp_tile = A_block_smem + (warp_m * WM * BK); @@ -954,10 +935,13 @@ constexpr unsigned int MMA_N = 8; for (int j = 0; j < 4; ++j){ const uint row = m_idx + subk + i * WN / 2; const uint gemm_i = n_idx + j*32; - const int n = fastdiv(gemm_i, param.OHOW_fastdiv); - const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - if(n < param.n && row < param.k && col < param.Oh * param.Ow){ - const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + // const int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + const int n = fastdiv(gemm_i, param.PQZ_fastdiv); + const int col = fastmodulo(gemm_i, param.PQZ_fastdiv); + if(n < param.n && row < param.k && col < PQZ){ + // const uint outOffset = n * param.k * param.Oh * param.Ow + row * 
param.Oh * param.Ow + col; + const uint outOffset = (n * param.k + row) * PQZ + col; uint idx = output_lds_addr + subk + j*32*BN/2; idx = idx ^ ((idx & 0b1110000000) >> 4); output[outOffset] = smemoutput[idx]; @@ -974,8 +958,6 @@ constexpr unsigned int MMA_N = 8; #endif } -#endif - #define NUM_VARIANTS 4 /* @@ -1021,64 +1003,64 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { - // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { - - // int id = ggml_cuda_get_device(); - - // int64_t ne = P.c * P.h * P.w * P.n; - // int64_t ne00 = P.c; - // int64_t ne01 = P.h * P.w; - // ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); - - // dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - // dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); - // NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); - - // ne = P.c * P.r * P.s * P.k; - // ne01 = P.r * P.s; - // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); - // dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - - // const half *X_H = input_f16.get(); - // const half *K_H = kernel_f16.get(); - // ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - - // constexpr unsigned int BM_dim = 256; - // constexpr unsigned int BN_dim = 256; - // constexpr unsigned int BK_dim = 32; - - // constexpr unsigned int WARPS_PER_BLOCK_M = 2; - // constexpr unsigned int WARPS_PER_BLOCK_N = 4; - // constexpr unsigned int WARPS_PER_BLOCK_K = 4; - - // constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; - // constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; - // constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; - // const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; - // const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; - // constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; - // constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; - // constexpr unsigned int NumThreads = ThreadsM * ThreadsN; - // const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - - // cudaFuncSetAttribute(conv3d_implicit_kernel, - // cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - // dim3 gridDim(BlocksN, BlocksM); - // dim3 blockDim(ThreadsN, ThreadsM); - - // conv3d_implicit_kernel - // <<>>(X_H, K_H, Y_H.get(), P); - // const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - // to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); - // } else{ + if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { + + int id = ggml_cuda_get_device(); + + int64_t ne = P.c * P.h * P.w * P.n; + int64_t ne00 = P.c; + int64_t ne01 = P.h * P.w; + ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); + + dim3 dimGrid( (ne01 + 
CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); + NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); + + ne = P.c * P.r * P.s * P.k; + ne01 = P.r * P.s; + ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + + const half *X_H = input_f16.get(); + const half *K_H = kernel_f16.get(); + ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n); + + constexpr unsigned int BM_dim = 256; + constexpr unsigned int BN_dim = 256; + constexpr unsigned int BK_dim = 32; + + constexpr unsigned int WARPS_PER_BLOCK_M = 2; + constexpr unsigned int WARPS_PER_BLOCK_N = 4; + constexpr unsigned int WARPS_PER_BLOCK_K = 4; + + constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; + constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; + constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim; + const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; + constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; + constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; + constexpr unsigned int NumThreads = ThreadsM * ThreadsN; + const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + + cudaFuncSetAttribute(conv3d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM); + dim3 blockDim(ThreadsN, ThreadsM); + + conv3d_implicit_kernel + <<>>(X_H, K_H, Y_H.get(), P); + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st); + } else{ conv3d_implicit_cuda(X_D, K_D, Y_D, P, st); - // } + } } diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index 4e14c15cd2235..04fc9109ed90b 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -35,6 +35,29 @@ typedef struct{ } param_t; +template +__device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) { + + const unsigned int cur0 = fastdiv(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const unsigned int cur0_res = fastmodulo(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const unsigned int cur1 = fastdiv(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const unsigned int cur1_res = fastmodulo(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const unsigned int cur2 = fastdiv(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int cur3 = fastmodulo(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int curC = layout == 0 ? cur3 : cur0; + const unsigned int curT = layout == 0 ? cur0 : cur1; + const unsigned int curR = layout == 0 ? 
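// Shared-memory sizing note for the tensor-core launcher above: with BM_dim = BN_dim = 256,
// BK_dim = 32 and half-precision tiles,
//   shmem_bytes = (256*32 + 32*256) * 2 * sizeof(half) = 65536 bytes,
// which matches the cudaFuncAttributeMaxDynamicSharedMemorySize opt-in of 64 KB set before
// the launch (the factor of 2 presumably covering the double-buffered A/B tiles).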
cur1 : cur2; + const unsigned int curS = layout == 0 ? cur2 : cur3; + return make_int4(curC, curT, curR, curS); + +} + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template(kidx, param); + const int curC = curIdx.x; + const int curT = curIdx.y; + const int curR = curIdx.z; + const int curS = curIdx.w; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + kidx; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + // TODO: move some checks outside of loop? + if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -102,6 +131,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA( const half* src, half* dst, // const unsigned int src_stride, + const unsigned int inNOffset, + const unsigned int inDepthOffset, const unsigned int inChannelOffset, param_t param ) @@ -129,28 +160,43 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + const unsigned int kidx = thread_col*8; + const int4 curIdx = inputIndices<0>(kidx, param); - #pragma unroll +#pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; - unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.dilation1; // input h - int curW = posw_ori + curS * param.dilation0; // input w + unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv); + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; + // unsigned int inOffset = n * inNOffset; + // const unsigned int curR = 
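// The dst_index XOR a few lines below is a shared-memory swizzle: consecutive 8-half
// vectors are spread across different banks so that the later reads of this tile
// (e.g. via ldmatrix) avoid bank conflicts. The same SWIZZLE_MASK_*/SWIZZLE_BITS_* pair
// has to be applied on the read side for the mapping to cancel out.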
fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // int curH = posh_ori + curR * param.dilation1; // input h + // int curW = posw_ori + curS * param.dilation0; // input w + + const int curD = posd_ori + curIdx.y * param.dilation2; // input d + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c){ - const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){ + int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_float4[dst_index] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -174,6 +220,8 @@ __device__ __forceinline__ void tileMemcpyLoadA( float4 (&dst_reg)[ELEMENTS_PER_THREAD], // const unsigned int src_stride, const unsigned int block_k, + const unsigned int inNOffset, + const unsigned int inDepthOffset, const unsigned int inChannelOffset, param_t param ){ @@ -196,23 +244,38 @@ __device__ __forceinline__ void tileMemcpyLoadA( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + const unsigned int kidx = block_k + thread_col*8; + const int4 curIdx = inputIndices<0>(kidx, param); + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; - unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.dilation1; // input h - int curW = posw_ori + curS * param.dilation0; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c){ - const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_reg[i] = 
reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; + unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv); + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + // unsigned int inOffset = n * param.c * param.h * param.w; + // const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // int curH = posh_ori + curR * param.dilation1; // input h + // int curW = posw_ori + curS * param.dilation0; // input w + const int curD = posd_ori + curIdx.y * param.dilation2; // input d + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){ + int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_reg[i] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -259,14 +322,21 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // - + // const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int kidx = block_k + thread_col*8; + const int4 curIdx = inputIndices<0>(kidx, param); + const int curC = curIdx.x; + const int curT = curIdx.y; + const int curR = curIdx.z; + const int curS = curIdx.w; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + // if (thread_row < param.k && curR < param.r && curS < param.s && 
curC < param.c){ + // TODO : move some checks outside of the loop + if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 53e37efd3178e..8b19f05c39fa2 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -323,13 +323,13 @@ int main(void) // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,8,3,3,3), - std::make_tuple(1280,1280,26,38,8,3,3,3), - std::make_tuple(320,1280,52,76,8,3,3,3), - std::make_tuple(1280,1280,52,76,8,3,3,3), - std::make_tuple(320,1280,104,152,8,3,3,3), - std::make_tuple(1280,1280,104,152,8,3,3,3), - std::make_tuple(320,1280,208,304,4,3,3,3), - std::make_tuple(640,1280,208,304,4,3,3,3), + // std::make_tuple(1280,1280,26,38,8,3,3,3), + // std::make_tuple(320,1280,52,76,8,3,3,3), + // std::make_tuple(1280,1280,52,76,8,3,3,3), + // std::make_tuple(320,1280,104,152,8,3,3,3), + // std::make_tuple(1280,1280,104,152,8,3,3,3), + // std::make_tuple(320,1280,208,304,4,3,3,3), + // std::make_tuple(640,1280,208,304,4,3,3,3), // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), // std::make_tuple(128,3,768,1024,3,3), @@ -367,7 +367,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, From 3308ccef918833ad60a3217961a94b393c4ca562 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 17:30:36 -0500 Subject: [PATCH 07/19] conv3d WIP: enabled tensor core path --- ggml/src/ggml-cuda/conv3d-implicit.cu | 8 ++++---- tests/test-conv3d.cpp | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index 4f01dfea8db40..76f887972a023 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -1007,9 +1007,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa int id = ggml_cuda_get_device(); - int64_t ne = P.c * P.h * P.w * P.n; + int64_t ne = P.c * P.d * P.h * P.w * P.n; int64_t ne00 = P.c; - int64_t ne01 = P.h * P.w; + int64_t ne01 = P.h * P.w * P.d; ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, @@ -1018,8 +1018,8 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); - ne = P.c * P.r * P.s * P.k; - ne01 = P.r * P.s; + ne = P.c * P.r * P.s * P.t * P.k; + ne01 = P.r * P.s * P.t; ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 8b19f05c39fa2..53e37efd3178e 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -323,13 +323,13 @@ int main(void) // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,8,3,3,3), - // std::make_tuple(1280,1280,26,38,8,3,3,3), - // std::make_tuple(320,1280,52,76,8,3,3,3), - // 
std::make_tuple(1280,1280,52,76,8,3,3,3), - // std::make_tuple(320,1280,104,152,8,3,3,3), - // std::make_tuple(1280,1280,104,152,8,3,3,3), - // std::make_tuple(320,1280,208,304,4,3,3,3), - // std::make_tuple(640,1280,208,304,4,3,3,3), + std::make_tuple(1280,1280,26,38,8,3,3,3), + std::make_tuple(320,1280,52,76,8,3,3,3), + std::make_tuple(1280,1280,52,76,8,3,3,3), + std::make_tuple(320,1280,104,152,8,3,3,3), + std::make_tuple(1280,1280,104,152,8,3,3,3), + std::make_tuple(320,1280,208,304,4,3,3,3), + std::make_tuple(640,1280,208,304,4,3,3,3), // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), // std::make_tuple(128,3,768,1024,3,3), @@ -367,7 +367,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; + int iterations = 20; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, From 2357922a2f5d4c74abdf4f80e3cb8c4fc1da3c99 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 08:46:17 -0500 Subject: [PATCH 08/19] fixed a bug now all test cases passed --- ggml/src/ggml-cuda/conv3d-implicit.cu | 120 +++---------------------- ggml/src/ggml-cuda/conv3d-implicit.cuh | 60 +++---------- tests/test-conv3d.cpp | 105 +++++++++++----------- 3 files changed, 76 insertions(+), 209 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index 76f887972a023..df5ed4578ad1f 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -163,7 +163,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint inKOffset = start_k + innerColA * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - const unsigned int gemm_i = bx * BM + innerRowA + offset; + const unsigned int gemm_i = bx * BM + innerRowA + offset; // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); @@ -173,26 +173,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; int inOffset = n * inNOffset; if(vec_load){ - // const uint cur0 = fastdiv(inKOffset, - // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - // const uint cur0_res = fastmodulo(inKOffset, - // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - // const uint cur1 = fastdiv(cur0_res, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - // const uint cur1_res = fastmodulo(cur0_res, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - // const uint cur2 = fastdiv(cur1_res, - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint cur3 = fastmodulo(cur1_res, - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint curC = layout == 0 ? cur3 : cur0; - // const uint curT = layout == 0 ? cur0 : cur1; - // const uint curR = layout == 0 ? cur1 : cur2; - // const uint curS = layout == 0 ? 
cur2 : cur3; const int4 curIdx = inputIndices(inKOffset, param); - // const int curD = posd_ori + curT * param.dilation2; // input w - // const int curH = posh_ori + curR * param.dilation1; // input h - // const int curW = posw_ori + curS * param.dilation0; // input w const int curD = posd_ori + curIdx.y * param.dilation2; // input w const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w @@ -214,43 +195,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - // const uint cur0 = fastdiv(inKOffset + i, - // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - // const uint cur0_res = fastmodulo(inKOffset + i, - // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset - // const uint cur1 = fastdiv(cur0_res, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - // const uint cur1_res = fastmodulo(cur0_res, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset - // const uint cur2 = fastdiv(cur1_res, - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint cur3 = fastmodulo(cur1_res, - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint curC = layout == 0 ? cur3 : cur0; - // const uint curT = layout == 0 ? cur0 : cur1; - // const uint curR = layout == 0 ? cur1 : cur2; - // const uint curS = layout == 0 ? cur2 : cur3; const int4 curIdx = inputIndices(inKOffset + i, param); - // const int curD = posd_ori + curT * param.dilation2; // input w - // const int curH = posh_ori + curR * param.dilation1; // input h - // const int curW = posw_ori + curS * param.dilation0; // input w const int curD = posd_ori + curIdx.y * param.dilation2; // input w const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w const int curC = curIdx.x; - // const uint cur0 = fastdiv(start_k + innerColA * 4 + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - // const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint curC = layout == 0 ? cur2 : cur0; - // const uint curR = layout == 0 ? cur0 : cur1; - // const uint curS = layout == 0 ? cur1 : cur2; - // const int curH = posh_ori + curR * param.d_h; // input h - // const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){ int inOffsetTmp = layout == 0 ? curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: @@ -360,12 +309,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint inKkOffset = innerColA * 4 + crs + BK; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - // const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - // const int posh_ori = fastdiv((ksplit > 0) ? 
npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - // const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - // int inOffset = n * param.c * param.h * param.w ; - const unsigned int gemm_i = bx * BM + innerRowA + offset; + const unsigned int gemm_i = bx * BM + innerRowA + offset; int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2; @@ -379,28 +323,10 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w const int curC = curIdx.x; - // const uint cur0 = fastdiv(innerColA * 4 + crs + BK, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint curC = layout == 0 ? cur2 : cur0; - // const uint curR = layout == 0 ? cur0 : cur1; - // const uint curS = layout == 0 ? cur1 : cur2; - - // const int curH = posh_ori + curR * param.d_h; // input h - // const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){ int inOffsetTmp = layout == 0 ? curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){ - // int inOffsetTmp = layout == 0 ? - // curH * inChannelOffset + curW * param.c + curC: - // curC * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; @@ -414,29 +340,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - // const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, - // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - // const uint curC = layout == 0 ? cur2 : cur0; - // const uint curR = layout == 0 ? cur0 : cur1; - // const uint curS = layout == 0 ? 
cur1 : cur2; - - // const int curH = posh_ori + curR * param.d_h; // input h - // const int curW = posw_ori + curS * param.d_w; // input w const int4 curIdx = inputIndices(inKkOffset + i, param); const int curD = posd_ori + curIdx.y * param.dilation2; // input w const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w const int curC = curIdx.x; - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - // int inOffsetTmp = layout == 0 ? - // curH * inChannelOffset + curW * param.c + curC: - // curC * inChannelOffset + curH * param.w + curW; if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){ int inOffsetTmp = layout == 0 ? curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: @@ -521,7 +429,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i; if (n < param.n && row < param.k && col < PQZ){ const uint outOffset = ksplit > 0 ? - // z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col : ((z * param.n + n) * param.k + row) * PQZ + col : (z * param.k + row) * PQZ + col; output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; @@ -790,7 +697,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input, const unsigned int K = param.c * param.r * param.s * param.t; const uint weightKOffset = K; //param.c * param.r * param.s * param.t; const uint inChannelOffset = param.c * param.w; - const uint inDepthOffset = param.h * param.c * param.w; + const uint inDepthOffset = param.h * param.c * param.w; const uint inNOffset = param.c * param.w * param.h * param.d; // loop bounds, constexpr where possible allows for loop unrolling @@ -854,7 +761,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input, if (block_k != num_block_tiles_k){ const half* A_block_gmem = input; const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inNOffset, inDepthOffset, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); } @@ -935,12 +842,9 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input, for (int j = 0; j < 4; ++j){ const uint row = m_idx + subk + i * WN / 2; const uint gemm_i = n_idx + j*32; - // const int n = fastdiv(gemm_i, param.OHOW_fastdiv); - // const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); const int n = fastdiv(gemm_i, param.PQZ_fastdiv); const int col = fastmodulo(gemm_i, param.PQZ_fastdiv); if(n < param.n && row < param.k && col < PQZ){ - // const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; const uint outOffset = (n * param.k + row) * PQZ + col; uint idx = output_lds_addr + subk + j*32*BN/2; idx = idx ^ ((idx & 0b1110000000) >> 4); @@ -1109,19 +1013,15 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint KW = kernel->ne[0]; // kernel_w const uint KH = kernel->ne[1]; // kernel_h const uint KD = kernel->ne[2]; // kernel_h - // const uint IC = input->ne[2]; // input_channels - // const uint OC = kernel->ne[3]; // ouptut_chanles - // const uint B = input->ne[3]; // n_batches - - param_t params = { B, - IC, + 
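// The param_t fields are width-major: stride0/padding0/dilation0 hold the W
// components, *1 the H components and *2 the D components (see the struct
// comments in conv3d-implicit.cuh), so the host side has to pass the launch
// parameters in (X, Y, Z) order. As a minimal sketch of the convolution
// arithmetic these fields feed into (standard output-extent formula; the exact
// helper ggml uses to compute OH/OW/OD may differ):
//
//     auto out_extent = [](int in, int k, int s, int p, int d) {
//         return (in + 2 * p - d * (k - 1) - 1) / s + 1;
//     };
//     // OW = out_extent(IW, KW, ST_X, PD_X, DL_X);
//     // OH = out_extent(IH, KH, ST_Y, PD_Y, DL_Y);
//     // OD = out_extent(ID, KD, ST_Z, PD_Z, DL_Z);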
param_t params = { B, + IC, IH, IW, ID, - OC, + OC, KH, KW, KD, - ST_Y, ST_X, ST_Z, - PD_Y, PD_X, PD_Z, - DL_Y, DL_X, DL_Z, + ST_X, ST_Y, ST_Z, + PD_X, PD_Y, PD_Z, + DL_X, DL_Y, DL_Z, OH, OW, OD, init_fastdiv_values(KW*IC), init_fastdiv_values(OW), diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index 04fc9109ed90b..9cd7fe4e9b290 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -11,15 +11,15 @@ typedef struct{ unsigned int r; //filter height unsigned int s; //filter width unsigned int t; //filter depth - unsigned int stride0; //stride width - unsigned int stride1; //stride height + unsigned int stride0; //stride width + unsigned int stride1; //stride height unsigned int stride2; //stride depth - unsigned int padding0; //padding width + unsigned int padding0; //padding width unsigned int padding1; //padding height - unsigned int padding2; //padding depth - unsigned int dilation0; //dilation width - unsigned int dilation1; //dilation height - unsigned int dilation2; //dilation depth + unsigned int padding2; //padding depth + unsigned int dilation0; //dilation width + unsigned int dilation1; //dilation height + unsigned int dilation2; //dilation depth unsigned int Oh; //output height unsigned int Ow; //output width unsigned int Od; //output depth @@ -39,17 +39,17 @@ template __device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) { const unsigned int cur0 = fastdiv(kidx, - layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); const unsigned int cur0_res = fastmodulo(kidx, - layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); const unsigned int cur1 = fastdiv(cur0_res, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); const unsigned int cur1_res = fastmodulo(cur0_res, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); const unsigned int cur2 = fastdiv(cur1_res, - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + layout == 0 ? param.C_fastdiv : param.S_fastdiv); const unsigned int cur3 = fastmodulo(cur1_res, - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + layout == 0 ? param.C_fastdiv : param.S_fastdiv); const unsigned int curC = layout == 0 ? cur3 : cur0; const unsigned int curT = layout == 0 ? cur0 : cur1; const unsigned int curR = layout == 0 ? 
cur1 : cur2; @@ -90,9 +90,6 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - // const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // const unsigned int kidx = thread_col*8; const int4 curIdx = inputIndices<0>(kidx, param); const int curC = curIdx.x; @@ -172,17 +169,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA( const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv); const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; - // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; - // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; - // unsigned int inOffset = n * inNOffset; - // const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // int curH = posh_ori + curR * param.dilation1; // input h - // int curW = posw_ori + curS * param.dilation0; // input w - const int curD = posd_ori + curIdx.y * param.dilation2; // input d const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w @@ -193,9 +179,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){ int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - // curR < param.r && curS < param.s && curC < param.c){ - // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -250,31 +233,18 @@ __device__ __forceinline__ void tileMemcpyLoadA( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; - // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv); const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2; const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv); const int posh_ori = 
fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; - // unsigned int inOffset = n * param.c * param.h * param.w; - // const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // int curH = posh_ori + curR * param.dilation1; // input h - // int curW = posw_ori + curS * param.dilation0; // input w const int curD = posd_ori + curIdx.y * param.dilation2; // input d const int curH = posh_ori + curIdx.z * param.dilation1; // input h const int curW = posw_ori + curIdx.w * param.dilation0; // input w const int curC = curIdx.x; if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){ int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC; - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - // curR < param.r && curS < param.s && curC < param.c){ - // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -322,9 +292,6 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - // const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // const unsigned int kidx = block_k + thread_col*8; const int4 curIdx = inputIndices<0>(kidx, param); const int curC = curIdx.x; @@ -334,7 +301,6 @@ __device__ __forceinline__ void tileMemcpyLoadB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; - // if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ // TODO : move some checks outside of the loop if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 53e37efd3178e..92e9d1e457d38 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -38,7 +38,9 @@ struct test_model { -void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int kw = 3, int kh = 3, int kd = 3, bool use_gpu = false ) { +void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, + int kw = 3, int kh = 3, int kd = 3, + bool use_fp16 = true, bool use_gpu = false ) { // create data int KW = kw, KH = kh, KD = kd; int IC = ic, OC = oc; @@ -72,9 +74,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int } size_t buffer_size = 0; - { - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + { if(use_fp16) + buffer_size += KW * KH * KD * IC * 
OC * ggml_type_size(GGML_TYPE_F16); // tensor a + else + buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a buffer_size += IW * IH * ID * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b buffer_size += 1024; // overhead } @@ -122,8 +125,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int model.ctx = ggml_init(params); // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC); - // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + if(use_fp16) + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC); + else + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, KD, IC*OC); model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, ID, IC*N); // create a allocator @@ -134,11 +139,15 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int // load data to buffer if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); - // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + if(use_fp16) + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + else + memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); } else { - ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); - // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + if(use_fp16) + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + else + ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); } // alloc memory @@ -155,7 +164,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int } } -typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, const int64_t i0, const int64_t i1, const int64_t i2); struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) { @@ -173,18 +182,27 @@ struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, co struct ggml_cgraph * gf = ggml_new_graph(ctx0); + // int s0 = 2; + // int s1 = 1; + // int s2 = 1; + // int p0 = 2; + // int p1 = 0; + // int p2 = 1; + // int d0 = 1; + // int d1 = 1; + // int d2 = 2; + int s0 = 1; int s1 = 1; int s2 = 1; int p0 = 1; int p1 = 1; int p2 = 1; + int d0 = 1; int d1 = 1; int d2 = 1; - - // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_3d(ctx0, model.a, model.b, ic, s0, s1, s2, p0, p1, p2, d0, d1, d2); ggml_set_name(conv2d_res, "conv2d_res"); @@ -227,6 +245,16 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co int d1 = 1; int d2 = 1; + // int s0 = 2; + // int s1 = 1; + // int s2 = 1; + // int p0 = 2; + // int p1 = 0; + // int p2 = 1; + // int d0 = 1; + // int d1 = 1; + // int d2 = 2; + // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); // ggml_set_name(conv2d_res, "conv2d_res"); @@ -236,7 +264,7 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - struct ggml_tensor* wino_res = ggml_conv_3d_direct(ctx0, model.a, model.b, + struct ggml_tensor* wino_res = ggml_conv_3d_direct(ctx0, model.a, model.b, s0, s1, s2, p0, p1, p2, d0, d1, d2, ic, n, oc); ggml_set_name(wino_res, 
"wino_res"); @@ -251,7 +279,7 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, - build_graph_t build_graph, int iters, + build_graph_t build_graph, int iters, const int64_t ic, const int64_t n, const int64_t oc, double *t) { struct ggml_cgraph * gf = build_graph(model, ic, n, oc); @@ -271,7 +299,6 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr } #endif - ggml_backend_graph_compute(model.backend, gf); @@ -289,8 +316,6 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr double time_us = end_time - start_time; time_us = time_us/iters; - // printf(" Taking %f ms\n ", time_us/1000); - //ggml_graph_print(gf); struct ggml_tensor *res = NULL; @@ -316,12 +341,6 @@ int main(void) { ggml_time_init(); std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,8,3,3,3), std::make_tuple(1280,1280,26,38,8,3,3,3), std::make_tuple(320,1280,52,76,8,3,3,3), @@ -330,29 +349,14 @@ int main(void) std::make_tuple(1280,1280,104,152,8,3,3,3), std::make_tuple(320,1280,208,304,4,3,3,3), std::make_tuple(640,1280,208,304,4,3,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), - // std::make_tuple(320,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(320,256,1024,1920) }; int k = 0; for (auto c : configs){ test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), - std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true); + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true, true); ggml_gallocr_t allocr = NULL; allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); @@ -366,11 +370,11 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - struct ggml_cgraph * gf_res_0 = NULL; + struct ggml_cgraph * gf_res_0 = NULL; int iterations = 20; double run_time0; - std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, + std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, std::get<0>(c), 1, std::get<1>(c), &run_time0); ggml_gallocr_free(allocr); @@ -386,23 +390,22 @@ int main(void) ggml_gallocr_reserve(allocr, gf); size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - struct ggml_cgraph * gf_res_1 = NULL; + struct ggml_cgraph * gf_res_1 = NULL; double run_time1; // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, 
iterations, + std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, std::get<0>(c), 1, std::get<1>(c), &run_time1); - if(k==0) { + if(k==0) { k = 1; fprintf(stderr, "| (IC, OC, IW, IH, ID, KW, KH, KD) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", - std::get<0>(c), std::get<1>(c), std::get<2>(c), + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), run_time0, mem_size0/1024.0f/1024.0f, @@ -412,7 +415,7 @@ int main(void) // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // printf("(%7.3f, %7.3f, %f, %d) \n", // im2col_data[i], conv2d_data[i], // diff, i); // // break; @@ -425,7 +428,5 @@ int main(void) ggml_gallocr_free(allocr); } - - // printf("\nPerforming test:\n"); return 0; } From 5aa4ae739d5951e3a8b099d2a5dad17afee328d1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 09:28:05 -0500 Subject: [PATCH 09/19] make CI happy --- tests/test-conv3d.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 92e9d1e457d38..517da9ce6ff3b 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -28,6 +28,7 @@ static void ggml_log_callback_default(ggml_log_level level, const char * text, v fflush(stderr); } + struct test_model { struct ggml_tensor * a; struct ggml_tensor * b; @@ -36,6 +37,14 @@ struct test_model { struct ggml_context * ctx; }; +struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); +struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, + const int64_t i0, const int64_t i1, const int64_t i2); + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, + const int64_t ic, const int64_t n, const int64_t oc, double *t); void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, @@ -101,6 +110,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } } +#else + GGML_UNUSED(use_gpu); #endif #ifdef GGML_USE_METAL @@ -112,6 +123,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } } +#else + GGML_UNUSED(use_gpu); #endif if(!model.backend) { @@ -164,10 +177,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, } } -typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, - const int64_t i0, const int64_t i1, const int64_t i2); - struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) { + + GGML_UNUSED(n); + GGML_UNUSED(oc); + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); @@ -370,7 +384,6 @@ int main(void) // fprintf(stderr, "%s: compute buffer 
size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - struct ggml_cgraph * gf_res_0 = NULL; int iterations = 20; double run_time0; @@ -389,12 +402,8 @@ int main(void) // compute the required memory ggml_gallocr_reserve(allocr, gf); size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - struct ggml_cgraph * gf_res_1 = NULL; double run_time1; - // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, std::get<0>(c), 1, std::get<1>(c), &run_time1); From 91650b7fdcaf38f0de09eabd8712d9023114a27b Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 09:30:51 -0500 Subject: [PATCH 10/19] one more CI fix --- tests/test-conv3d.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 517da9ce6ff3b..a004798522656 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -37,6 +37,9 @@ struct test_model { struct ggml_context * ctx; }; +void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, + int kw = 3, int kh = 3, int kd = 3, + bool use_fp16 = true, bool use_gpu = false ); struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, @@ -49,7 +52,7 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int kw = 3, int kh = 3, int kd = 3, - bool use_fp16 = true, bool use_gpu = false ) { + bool use_fp16 = true, bool use_gpu = false ) { // create data int KW = kw, KH = kh, KD = kd; int IC = ic, OC = oc; From f0ced9fb3d5fbdda7a65d5baab50badd4f67c620 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 10:10:51 -0500 Subject: [PATCH 11/19] add some test cases in test-backend-op perf --- tests/test-backend-ops.cpp | 44 ++++++++++++++++++++++++++++++++++++++ tests/test-conv3d.cpp | 4 ++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 04fa1b62d3b4d..f8a9b8f4b8dbe 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7374,6 +7374,50 @@ static std::vector> make_test_cases_perf() { } } + + for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (int N : {1}) { + for (int IC : {320, 640}) { + for (int OC : {320, 640}) { + for (int s0 : {1}) { + for (int p1 : {1}) { + for (int d2 : {1}) { + int64_t IW = 26, IH = 38, ID = 8; + int64_t KW = 3, KH = 3, KD = 3; + int s1 = s0, s2 = s0; + int p0 = p1, p2 = p1; + int d0 = d2, d1 = d2; + test_cases.emplace_back(new test_conv_3d( + N, IC, ID, IH, IW, + OC, KD, KH, KW, + s0, s1, s2, p0, p1, p2, d0, d1, d2, + kernel_type)); + IW = 52; IH = 76; + test_cases.emplace_back(new test_conv_3d( + N, IC, ID, IH, IW, + OC, KD, KH, KW, + s0, s1, s2, p0, p1, p2, d0, d1, d2, + kernel_type)); + IW = 104; IH = 158; + test_cases.emplace_back(new test_conv_3d( + N, IC, ID, IH, IW, + OC, KD, KH, KW, + s0, s1, s2, p0, p1, p2, d0, d1, d2, + kernel_type)); + IW = 208; IH = 316; + test_cases.emplace_back(new test_conv_3d( + N, IC, ID, IH, IW, + OC, KD, KH, KW, + s0, s1, s2, p0, p1, p2, d0, d1, d2, + kernel_type)); + } + } 
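// Taken together these loops enqueue 2 (kernel types) x 2 (IC) x 2 (OC) x 4
// (spatial sizes) = 32 perf cases, all with N = 1, D = 8, a 3x3x3 kernel and
// unit stride/padding/dilation. A minimal sketch of what the first generated
// case expands to, assuming test_conv_3d keeps the argument order used above
// (N, IC, ID, IH, IW, OC, KD, KH, KW, strides, paddings, dilations, type):
//
//     test_cases.emplace_back(new test_conv_3d(
//         /*N*/ 1, /*IC*/ 320, /*ID*/ 8, /*IH*/ 38, /*IW*/ 26,
//         /*OC*/ 320, /*KD*/ 3, /*KH*/ 3, /*KW*/ 3,
//         /*s*/ 1, 1, 1, /*p*/ 1, 1, 1, /*d*/ 1, 1, 1,
//         GGML_TYPE_F32));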
+ } + } + } + } + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index a004798522656..f5c23aa94055a 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -38,8 +38,8 @@ struct test_model { }; void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, - int kw = 3, int kh = 3, int kd = 3, - bool use_fp16 = true, bool use_gpu = false ); + int kw, int kh, int kd, + bool use_fp16, bool use_gpu); struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc); typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model, From f9212865a92420a270925dc31d296e066e49c07f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 10:40:29 -0500 Subject: [PATCH 12/19] avoid CI time out on test-conv3d --- tests/test-conv3d.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index f5c23aa94055a..97c61c77cbe39 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -358,14 +358,15 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(320,1280,26,38,8,3,3,3), - std::make_tuple(1280,1280,26,38,8,3,3,3), - std::make_tuple(320,1280,52,76,8,3,3,3), - std::make_tuple(1280,1280,52,76,8,3,3,3), - std::make_tuple(320,1280,104,152,8,3,3,3), - std::make_tuple(1280,1280,104,152,8,3,3,3), - std::make_tuple(320,1280,208,304,4,3,3,3), - std::make_tuple(640,1280,208,304,4,3,3,3), + std::make_tuple(1,2,16,32,4,3,3,3), + // std::make_tuple(320,1280,26,38,8,3,3,3), + // std::make_tuple(1280,1280,26,38,8,3,3,3), + // std::make_tuple(320,1280,52,76,8,3,3,3), + // std::make_tuple(1280,1280,52,76,8,3,3,3), + // std::make_tuple(320,1280,104,152,8,3,3,3), + // std::make_tuple(1280,1280,104,152,8,3,3,3), + // std::make_tuple(320,1280,208,304,4,3,3,3), + // std::make_tuple(640,1280,208,304,4,3,3,3), }; int k = 0; From 36c0df79041916b0293983152daed8bf4e9cf130 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 10:52:47 -0500 Subject: [PATCH 13/19] fix metal related CI stuff --- tests/test-conv3d.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp index 97c61c77cbe39..0beeca43951fd 100644 --- a/tests/test-conv3d.cpp +++ b/tests/test-conv3d.cpp @@ -120,7 +120,6 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, #ifdef GGML_USE_METAL if (use_gpu) { fprintf(stderr, "%s: using Metal backend\n", __func__); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); model.backend = ggml_backend_metal_init(); if (!model.backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); @@ -310,13 +309,6 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr ggml_backend_cpu_set_n_threads(model.backend, n_threads); } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif - - ggml_backend_graph_compute(model.backend, gf); ggml_backend_synchronize(model.backend); From d2d814c15663fdefeb9adedc683b6f26c9aed9f0 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 9 Nov 2025 17:30:08 -0500 Subject: 
[PATCH 14/19] fixed a bug in calculating filter row index

---
 ggml/src/ggml-cuda/conv3d-implicit.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh
index 9cd7fe4e9b290..37449f677e9b1 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh
@@ -104,7 +104,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         // TODO: move some checks outside of loop?
-        if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
             dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -302,7 +302,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
         // TODO : move some checks outside of the loop
-        if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
             dst_reg[i] = reinterpret_cast(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);

From a428feecdd2d6a5643038854f7dc1b43ecba2358 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Mon, 10 Nov 2025 13:13:36 -0500
Subject: [PATCH 15/19] fuse cast to float into conv epilogue; improve
 swizzling for output

---
 ggml/src/ggml-cuda/conv3d-implicit.cu | 49 ++++++++++++++++-----------
 tests/test-conv3d.cpp                 |  6 ++--
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu
index df5ed4578ad1f..89e9ebf2a6b90 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cu
@@ -681,11 +681,11 @@ __device__ __forceinline__ void ldmatrix_b(
 #endif
 }
 
-template
+template
 static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                                    const half * __restrict__ kernel,
-                                   half * __restrict__ output,
+                                   T * __restrict__ output,
                                    const param_t param) {
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -828,27 +828,36 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                 uint32_t (&reg_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]);
                 uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
+                idx = idx ^ ((idx & 0b110000000000) >> 9);
                 idx = idx ^ ((idx & 0b1110000000) >> 4);
                 uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]);
                 dst_ptr[0] = reg_[0];
-                dst_ptr = reinterpret_cast(&smemoutput[idx + 8 * BN / 2]);
+                idx = (idx + 8 * BN / 2 ) ^ 0b010;
+                dst_ptr = reinterpret_cast(&smemoutput[idx]);
                 dst_ptr[0] = reg_[1];
             }
         }
         __syncthreads();
 #pragma unroll
-        for (int subk = 0; subk < WN / 2; ++subk){
+        for (int subk = 0; subk < WN / 4; ++subk){
+            const uint row = m_idx + subk*2 + i * WN / 2;
+            uint idx = output_lds_addr + subk*2; // + j*32*BN/2;
+            idx = idx ^ ((idx & 0b110000000000) >> 9);
+            idx = idx ^ ((idx & 0b1110000000) >> 4);
             for (int j = 0; j < 4; ++j){
-                const uint row = m_idx + subk + i * WN / 2;
                 const uint gemm_i = n_idx + j*32;
                 const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
                 const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
+                uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2]));
+                half (&res_)[2] = reinterpret_cast(dst_ptr);
                 if(n < param.n && row < param.k && col < PQZ){
                     const uint outOffset = (n * param.k + row) * PQZ + col;
-                    uint idx = output_lds_addr + subk + j*32*BN/2;
-                    idx = idx ^ ((idx & 0b1110000000) >> 4);
-                    output[outOffset] = smemoutput[idx];
+                    output[outOffset] = ggml_cuda_cast(res_[0]);
+                }
+                if(n < param.n && row+1 < param.k && col < PQZ){
+                    const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                    output[outOffset] = ggml_cuda_cast(res_[1]);
                 }
             }
         }
@@ -924,15 +933,17 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         ne = P.c * P.r * P.s * P.t * P.k;
         ne01 = P.r * P.s * P.t;
 
-        ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne);
-        dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
-                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
-                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
-        NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+        ggml_cuda_pool_alloc kernel_f16(ctx.pool(id));
+        if(ne01 > 1){
+            kernel_f16.alloc(ne);
+            dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                          (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                          (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
+            NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+        }
 
         const half *X_H = input_f16.get();
-        const half *K_H = kernel_f16.get();
-        ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n);
+        const half *K_H = ne01 == 1 ? K_D : kernel_f16.get();
 
         constexpr unsigned int BM_dim = 256;
         constexpr unsigned int BN_dim = 256;
@@ -952,16 +963,14 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
         const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
 
-        cudaFuncSetAttribute(conv3d_implicit_kernel,
+        cudaFuncSetAttribute(conv3d_implicit_kernel,
             cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
-        conv3d_implicit_kernel
-            <<>>(X_H, K_H, Y_H.get(), P);
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st);
+            <<>>(X_H, K_H, Y_D, P);
     } else{
         conv3d_implicit_cuda(X_D, K_D, Y_D, P, st);
     }
diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp
index 0beeca43951fd..6483841a42c6e 100644
--- a/tests/test-conv3d.cpp
+++ b/tests/test-conv3d.cpp
@@ -350,9 +350,9 @@ int main(void)
 {
     ggml_time_init();
 
     std::vector> configs = {
-        std::make_tuple(1,2,16,32,4,3,3,3),
+        // std::make_tuple(1,2,16,32,4,3,3,3),
         // std::make_tuple(320,1280,26,38,8,3,3,3),
-        // std::make_tuple(1280,1280,26,38,8,3,3,3),
+        std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(320,1280,52,76,8,3,3,3),
         // std::make_tuple(1280,1280,52,76,8,3,3,3),
         // std::make_tuple(320,1280,104,152,8,3,3,3),
@@ -380,7 +380,7 @@ int main(void)
 
     // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
 
-    int iterations = 20;
+    int iterations = 0;
 
     double run_time0;
     std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations,

From 15daa5a6a8667d58ce8a02c87b3d790db3e28e55 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Mon, 10 Nov 2025 14:38:23 -0500
Subject: [PATCH 16/19] added split-k mode to tensor core path

---
 ggml/src/ggml-cuda/conv3d-implicit.cu  | 159 +++++++++++++++++++++----
 ggml/src/ggml-cuda/conv3d-implicit.cuh |  29 +++--
 tests/test-conv3d.cpp                  |   4 +-
 3 files changed, 158 insertions(+), 34 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu
index 89e9ebf2a6b90..fd6e1d0b71386 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cu
@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
 
 
 //currently not use; in future for split-k kernels
-// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
-//     const int row = blockIdx.x;
-//     const int col = threadIdx.x;
-
-//     float sum = 0.0f;
-//     if (row * blockDim.x + col < ncols) {
-//         for (int i = 0; i < nrows; ++i){
-//             sum += x[i * ncols + row * blockDim.x + col];
-//         }
-//         dst[row * blockDim.x + col] = sum;
-//     }
-// }
+template
+static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    if (row * blockDim.x + col < ncols) {
+        for (int i = 0; i < nrows; ++i){
+            sum += ggml_cuda_cast(x[i * ncols + row * blockDim.x + col]);
+        }
+        dst[row * blockDim.x + col] = ggml_cuda_cast(sum);
+    }
+}
 
 template
 static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@@ -682,7 +683,7 @@ __device__ __forceinline__ void ldmatrix_b(
 }
 
 template
+          const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                                    const half * __restrict__ kernel,
                                    T * __restrict__ output,
                                    const param_t param) {
@@ -699,12 +700,19 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
     const uint inChannelOffset = param.c * param.w;
     const uint inDepthOffset = param.h * param.c * param.w;
     const uint inNOffset = param.c * param.w * param.h * param.d;
+    const unsigned int z = blockIdx.z;
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
+
+    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
+    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
+
+    // const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
@@ -750,8 +758,8 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
 
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
-    tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param);
+    tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inNOffset, inDepthOffset, inChannelOffset, param);
+    tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
@@ -761,9 +769,10 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
         if (block_k != num_block_tiles_k){
             const half* A_block_gmem = input;
             const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
-            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK,
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k,
                             inNOffset, inDepthOffset, inChannelOffset, param);
-            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k,
+                            weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
         half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -852,12 +861,28 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                 uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2]));
                 half (&res_)[2] = reinterpret_cast(dst_ptr);
                 if(n < param.n && row < param.k && col < PQZ){
-                    const uint outOffset = (n * param.k + row) * PQZ + col;
+                    // if constexpr (ksplit > 0) {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast(res_[0]);
+                    // } else {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast(res_[0]);
+                    // }
+                    const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row) * PQZ + col :
+                                                        (n * param.k + row) * PQZ + col;
                     output[outOffset] = ggml_cuda_cast(res_[0]);
                 }
                 if(n < param.n && row+1 < param.k && col < PQZ){
-                    const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                    const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row+1) * PQZ + col :
+                                                        (n * param.k + row+1) * PQZ + col;
                     output[outOffset] = ggml_cuda_cast(res_[1]);
+                    // if constexpr (ksplit > 0) {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast(res_[0]);
+                    // } else {
+                    //     const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast(res_[1]);
+                    // }
                 }
             }
         }
@@ -914,6 +939,33 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
                                  WNITER, TM, TN, NUM_THREADS, 1, false, 0><<>>(X_D, K_D, Y_D, P);
 }
 
+template
+static void launch_conv3d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
+                                                const unsigned int BlocksM, const unsigned int BlocksN,
+                                                const unsigned int shmem_bytes,
+                                                const param_t P, cudaStream_t st){
+
+    int id = ggml_cuda_get_device();
+
+    ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Od * P.Oh * P.Ow * P.n);
+    cudaFuncSetAttribute(conv3d_implicit_kernel,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
+    dim3 gridDim(BlocksN, BlocksM, ksplit);
+    dim3 blockDim(ThreadsN, ThreadsM);
+
+    conv3d_implicit_kernel
+        <<>>(X_H, K_H, Y_H.get(), P);
+
+    const unsigned int nrows = P.n * P.k * P.Oh * P.Ow * P.Od;
+    const unsigned int blockx = (nrows + 511) / 512;
+    const dim3 block_nums(blockx, 1, 1);
+    const dim3 block_dims(512, 1, 1);
+    reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit);
+}
+
 static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D,
                                      int cc, const param_t P, cudaStream_t st) {
     if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
@@ -956,6 +1008,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
         constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
         constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
+
+        static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free");
+
         const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
         const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
         constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
@@ -963,13 +1018,73 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
         const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
 
-        cudaFuncSetAttribute(conv3d_implicit_kernel,
+        const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
+        if (BlocksM * BlocksN < 2*(unsigned int)nsm){
+            int j, max_remaining_waves = -1, candidate = -1;
+            int ks = min(12, nsm / (BlocksM * BlocksN));
+            if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
+                ks = 12;
+            for (j = 2; j <= ks; j++){
+                const int remainder = (BlocksM * BlocksN * j) % nsm;
+                if ((P.c * P.r * P.s * P.t) % (8*j) == 0){
+                    if (remainder == 0) {
+                        candidate = j;
+                        max_remaining_waves = 0;
+                        break;
+                    } else if (remainder > max_remaining_waves) {
+                        max_remaining_waves = remainder;
+                        candidate = j;
+                    }
+                }
+            }
+
+            if(candidate != -1){
+                j = candidate;
+                if (j == 2) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 3) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 4) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 5) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 6) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 7) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 8) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 9) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 10) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 11) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 12) {
+                    launch_conv3d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                }
+                return;
+            }
+        }
+        cudaFuncSetAttribute(conv3d_implicit_kernel,
             cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
 
         conv3d_implicit_kernel
+                               WM_dim, WN_dim, WK_dim, 0, NumThreads>
             <<>>(X_H, K_H, Y_D, P);
     } else{
         conv3d_implicit_cuda(X_D, K_D, Y_D, P, st);
diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh
index 37449f677e9b1..d7d8ef10868e6 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh
@@ -65,6 +65,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -90,7 +92,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -104,7 +106,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         // TODO: move some checks outside of loop?
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c && kidx < end_k){
             dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -127,7 +129,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
-    // const unsigned int src_stride,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -157,7 +160,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
 
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k+thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
 #pragma unroll
@@ -177,7 +180,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -203,6 +206,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -227,7 +232,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
 #pragma unroll
@@ -243,7 +248,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         const int curH = posh_ori + curIdx.z * param.dilation1; // input h
         const int curW = posw_ori + curIdx.w * param.dilation0; // input w
         const int curC = curIdx.x;
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
+            && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_reg[i] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -270,6 +276,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -292,7 +300,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -300,9 +308,10 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const int curS = curIdx.w;
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
+        const unsigned int src_index = thread_row * src_stride + kidx;
         // TODO : move some checks outside of the loop
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t
+            && curC < param.c && kidx < end_k){
             dst_reg[i] = reinterpret_cast(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp
index 6483841a42c6e..e05bc50090356 100644
--- a/tests/test-conv3d.cpp
+++ b/tests/test-conv3d.cpp
@@ -350,9 +350,9 @@ int main(void)
 {
     ggml_time_init();
 
     std::vector> configs = {
-        // std::make_tuple(1,2,16,32,4,3,3,3),
+        std::make_tuple(1,2,16,32,4,3,3,3),
         // std::make_tuple(320,1280,26,38,8,3,3,3),
-        std::make_tuple(1280,1280,26,38,8,3,3,3),
+        // std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(320,1280,52,76,8,3,3,3),
         // std::make_tuple(1280,1280,52,76,8,3,3,3),
         // std::make_tuple(320,1280,104,152,8,3,3,3),

From 89103a856c8437b668f75aae86fd736465bddfab Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Mon, 10 Nov 2025 15:26:09 -0500
Subject: [PATCH 17/19] updated test cases using tensor shapes from sd.cpp wan
 video generation

---
 tests/test-backend-ops.cpp | 19 ++++++++++---------
 tests/test-conv3d.cpp      |  9 +++++++--
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f8a9b8f4b8dbe..49a4061f0708b 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7375,10 +7375,11 @@ static std::vector> make_test_cases_perf() {
     }
 
 
-    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+    // for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+    for (ggml_type kernel_type : {GGML_TYPE_F16}) {
         for (int N : {1}) {
-            for (int IC : {320, 640}) {
-                for (int OC : {320, 640}) {
+            for (int IC : {48, 320, 640, 1024}) {
+                for (int OC : {320, 640, 1024, 2048}) {
                     for (int s0 : {1}) {
                         for (int p1 : {1}) {
                             for (int d2 : {1}) {
@@ -7404,12 +7405,12 @@ static std::vector> make_test_cases_perf() {
                                     OC, KD, KH, KW,
                                     s0, s1, s2, p0, p1, p2, d0, d1, d2,
                                     kernel_type));
-                                IW = 208; IH = 316;
-                                test_cases.emplace_back(new test_conv_3d(
-                                    N, IC, ID, IH, IW,
-                                    OC, KD, KH, KW,
-                                    s0, s1, s2, p0, p1, p2, d0, d1, d2,
-                                    kernel_type));
+                                // IW = 208; IH = 316;
+                                // test_cases.emplace_back(new test_conv_3d(
+                                //     N, IC, ID, IH, IW,
+                                //     OC, KD, KH, KW,
+                                //     s0, s1, s2, p0, p1, p2, d0, d1, d2,
+                                //     kernel_type));
                             }
                         }
                     }
diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp
index e05bc50090356..a34169c36df62 100644
--- a/tests/test-conv3d.cpp
+++ b/tests/test-conv3d.cpp
@@ -358,7 +358,12 @@ int main(void)
         // std::make_tuple(320,1280,104,152,8,3,3,3),
         // std::make_tuple(1280,1280,104,152,8,3,3,3),
         // std::make_tuple(320,1280,208,304,4,3,3,3),
-        // std::make_tuple(640,1280,208,304,4,3,3,3),
+        std::make_tuple(1024,2048,30,52,3,3,3,3),
+        std::make_tuple(1024,2048,52,76,4,3,3,3),
+        std::make_tuple(1024,2048,52,76,6,3,3,3),
+        std::make_tuple(48,3072,64,64,9,2,2,1),
+        std::make_tuple(48,3072,64,64,17,2,2,1),
+        std::make_tuple(48,3072,64,64,33,2,2,1),
     };
 
     int k = 0;
@@ -380,7 +385,7 @@ int main(void)
 
     // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
 
-    int iterations = 0;
+    int iterations = 20;
 
     double run_time0;
     std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations,

From 5e1352cb60dd4eddadd57ee35d88393ef61c0ce7 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Mon, 10 Nov 2025 15:33:32 -0500
Subject: [PATCH 18/19] add a case having a memory access violation

---
 ggml/src/ggml-cuda/conv3d-implicit.cu |  1 -
 tests/test-conv3d.cpp                 | 15 ++++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu
index fd6e1d0b71386..31eaedc3dfe8e 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cu
@@ -1082,7 +1082,6 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
             cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
-
         conv3d_implicit_kernel
             <<>>(X_H, K_H, Y_D, P);
diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp
index a34169c36df62..a7ad5ab3573b1 100644
--- a/tests/test-conv3d.cpp
+++ b/tests/test-conv3d.cpp
@@ -350,7 +350,7 @@ int main(void)
 {
     ggml_time_init();
 
     std::vector> configs = {
-        std::make_tuple(1,2,16,32,4,3,3,3),
+        // std::make_tuple(1,2,16,32,4,3,3,3),
         // std::make_tuple(320,1280,26,38,8,3,3,3),
         // std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(320,1280,52,76,8,3,3,3),
@@ -358,12 +358,13 @@ int main(void)
         // std::make_tuple(320,1280,104,152,8,3,3,3),
        // std::make_tuple(1280,1280,104,152,8,3,3,3),
         // std::make_tuple(320,1280,208,304,4,3,3,3),
-        std::make_tuple(1024,2048,30,52,3,3,3,3),
-        std::make_tuple(1024,2048,52,76,4,3,3,3),
-        std::make_tuple(1024,2048,52,76,6,3,3,3),
-        std::make_tuple(48,3072,64,64,9,2,2,1),
-        std::make_tuple(48,3072,64,64,17,2,2,1),
-        std::make_tuple(48,3072,64,64,33,2,2,1),
+        // std::make_tuple(1024,2048,30,52,3,3,3,3),
+        // std::make_tuple(1024,2048,52,76,4,3,3,3),
+        // std::make_tuple(1024,2048,52,76,6,3,3,3),
+        // std::make_tuple(48,3072,64,64,9,2,2,1),
+        // std::make_tuple(48,3072,64,64,17,2,2,1),
+        // std::make_tuple(48,3072,64,64,33,2,2,1),
+        std::make_tuple(320,320,104,158,8,3,3,3),
     };
 
     int k = 0;

From d6d24487c213a8742325621cf5371b699f83b9eb Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Mon, 10 Nov 2025 16:52:46 -0500
Subject: [PATCH 19/19] fixed a bug of not bounds-checking the batch dimension

---
 ggml/src/ggml-cuda/conv3d-implicit.cuh | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh
index d7d8ef10868e6..c226351388df3 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh
@@ -180,7 +180,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d &&
+            n < param.n && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -249,7 +250,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         const int curW = posw_ori + curIdx.w * param.dilation0; // input w
         const int curC = curIdx.x;
         if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
-            && curC < param.c && kidx < end_k){
+            && n < param.n && curC < param.c && kidx < end_k){
            int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
            dst_reg[i] = reinterpret_cast(&src[n * inNOffset + inOffsetTmp])[0];
        } else{
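
Note on the split-k path added in PATCH 16: each blockIdx.z slice computes a partial GEMM over its share of the reduction dimension (C*T*R*S) into a scratch buffer holding ksplit copies of the output, and a second kernel sums the copies into the final float result. The following is a minimal standalone sketch of that reduction step only; the kernel and variable names (reduce_partial_sums, partial, nelem) are illustrative and are not the symbols used in the patch above.

// sketch: sum ksplit half-precision partial outputs into a float output
#include <cuda_fp16.h>

__global__ void reduce_partial_sums(const half * __restrict__ partial,
                                    float * __restrict__ out,
                                    const int nelem, const int ksplit) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < nelem) {
        float sum = 0.0f;
        for (int s = 0; s < ksplit; ++s) {
            // split s wrote its partial result at offset s * nelem
            sum += __half2float(partial[(size_t) s * nelem + i]);
        }
        out[i] = sum;
    }
}

// launch: one thread per output element, e.g.
// reduce_partial_sums<<<(nelem + 511) / 512, 512, 0, stream>>>(partial, out, nelem, ksplit);

Splitting only pays off when BlocksM * BlocksN alone cannot fill the SMs, which is why the host-side heuristic in PATCH 16 searches for the smallest split factor that both divides the reduction dimension into 8-element chunks and leaves as little of the last wave idle as possible.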
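Note on the epilogue swizzling changed in PATCH 15: indices into the shared-memory output tile are XOR-ed with shifted copies of their own high (row) bits so that the stores from the MMA accumulators and the subsequent loads for the global write land in distinct banks. A minimal sketch of the index transform, reusing the masks that appear in the patch but wrapped in a hypothetical helper (the real kernel inlines these two lines and derives the masks from the tile width):

// sketch: XOR-swizzle a linear index into a shared-memory tile
__device__ __forceinline__ unsigned int swizzle_idx(unsigned int idx) {
    // fold selected row bits into the column bits; example masks/shifts
    idx = idx ^ ((idx & 0b110000000000u) >> 9);
    idx = idx ^ ((idx & 0b1110000000u) >> 4);
    return idx;
}

Because both the store path and the load path apply the same transform, the data round-trips through shared memory unchanged while consecutive threads in a warp touch different banks.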