Commits (117 total; changes shown are from 53 commits)
8a58931
Add implicit GEMM convolution operation for 2D tensors in CUDA
bssrdf Sep 3, 2025
4d77287
Add implicit convolution support for 2D tensors in CPU and CUDA imple…
bssrdf Sep 3, 2025
3877608
fix passing param as reference
bssrdf Sep 3, 2025
6d84cbb
Fix parameter order in conv2d_implicit and add comprehensive test cas…
bssrdf Sep 3, 2025
5ffe97b
Fix boundary check in conv2d_implicit_kernel to include channel limits
bssrdf Sep 4, 2025
4b0f9d5
Refactor conv2d_implicit_kernel for improved readability and consiste…
bssrdf Sep 5, 2025
83a3b7d
Refactor conv2d_implicit_kernel for improved bitwise operations; add …
bssrdf Sep 6, 2025
735886b
merged with upstream master
bssrdf Sep 10, 2025
2ec76aa
Merge branch 'master' into conv2d-implicit
bssrdf Sep 11, 2025
53a2ccb
minor update and add direct conv in benchmarking
bssrdf Sep 25, 2025
c625544
minor updates
bssrdf Oct 8, 2025
0ca4358
reorder register tile loop
bssrdf Oct 8, 2025
16b0f0a
work in progress
bssrdf Oct 13, 2025
2237722
added block variants; to be debugged
bssrdf Oct 14, 2025
3e2f722
fixed missing dilation
bssrdf Oct 14, 2025
b70cca2
add support for both NCHW and NHWC layouts
bssrdf Oct 14, 2025
3f99818
unroll some loops
bssrdf Oct 15, 2025
ac77b8d
change padding size to 1; added padding to input smem
bssrdf Oct 15, 2025
6a1f8b4
change padding size back to 4
bssrdf Oct 15, 2025
15484c9
turn on tests for implicit conv2d
bssrdf Oct 18, 2025
f0a480c
WIP
bssrdf Oct 21, 2025
f931ad8
WIP
Oct 21, 2025
1b69ed4
WIP
Oct 21, 2025
215ebf6
WIP
bssrdf Oct 22, 2025
66f6d16
WIP
bssrdf Oct 23, 2025
2715341
WIP: output
bssrdf Oct 24, 2025
80a996c
WIP: tensor core code compiled ok
bssrdf Oct 24, 2025
be25be8
WIP: debugging tensor core kernel
bssrdf Oct 24, 2025
6c90c20
WIP: bug fix
bssrdf Oct 24, 2025
24b5532
WIP: fixed another bug
bssrdf Oct 24, 2025
980ddc1
properly use __CUDA_ARCH__ to protect the tensor path
bssrdf Oct 25, 2025
c45df12
this case is broken; to be debugged
bssrdf Oct 25, 2025
610e41a
still debugging
bssrdf Oct 25, 2025
396f558
WIP: bug fix
bssrdf Oct 25, 2025
475f987
WIP: fixed another bug
bssrdf Oct 26, 2025
c68fe36
WIP: cleanup; enhanced test case
bssrdf Oct 26, 2025
3099078
WIP
bssrdf Oct 27, 2025
cc327f5
added a specialization for cuda copy op when tensor is transposed
Oct 27, 2025
a3784e1
WIP: debugging cpy transpose
bssrdf Oct 27, 2025
6d12288
WIP: fixed a bug in cpy transpose index computation
Oct 27, 2025
3ea524e
WIP: almost working
bssrdf Oct 28, 2025
75dde41
WIP: minor tweak
bssrdf Oct 28, 2025
4b1920e
reduced bank conflicts for output
bssrdf Oct 29, 2025
1e56825
switch to default conv2d interface
bssrdf Oct 29, 2025
2dfbbee
clean up
bssrdf Oct 29, 2025
55859a8
remove implicit op and related calls; replace conv_2d with conv_2d_im…
bssrdf Oct 30, 2025
a3b4d8d
clean up
bssrdf Oct 30, 2025
7013227
more clean up
bssrdf Oct 30, 2025
1f3d5eb
prevent CI compile failure
bssrdf Oct 30, 2025
c141ce3
make CI happy
bssrdf Oct 30, 2025
2b5351a
make CI happy
bssrdf Oct 30, 2025
c1f67c1
make CI happy
bssrdf Oct 30, 2025
417cfc3
added a test case to directly compare im2col and implicit gemm
bssrdf Oct 31, 2025
f95664c
make tensor core path available for cc 7.5 and above
bssrdf Nov 1, 2025
fa9e415
minor update of test case
bssrdf Nov 3, 2025
27881fb
fixes for CI
bssrdf Nov 4, 2025
8572313
remove trailing blank
bssrdf Nov 4, 2025
00a49c2
another CI fix
bssrdf Nov 4, 2025
275c08d
add more sd like test cases
bssrdf Nov 4, 2025
6f44f47
added split-k mode for skinny mnk shapes
bssrdf Nov 5, 2025
688de6d
fixed bug now split-k is working
bssrdf Nov 5, 2025
d9a4858
use a better criterion for when to use split-k
bssrdf Nov 5, 2025
09e3a5f
try to reduce index calculation
bssrdf Nov 6, 2025
68ccd2a
refactor cuda core code path
bssrdf Nov 6, 2025
311213d
make sure there are enough channels for split-k
bssrdf Nov 6, 2025
28b7094
Merge branch 'refactor-cuda-core-path' into conv2d-implicit
bssrdf Nov 6, 2025
ba70ad8
added test cases exactly replicating sdxl unet steps
bssrdf Nov 7, 2025
4e9ebe9
minor update
bssrdf Nov 7, 2025
df88b2c
trying to get rid of remaining bank conflicts; also fixed a bug for s…
bssrdf Nov 7, 2025
76885c7
WIP: debugging
Nov 7, 2025
949eca4
swizzling working, may still have room to optimize
bssrdf Nov 8, 2025
8809af7
now bank-conflict free and performance gets a bit boosted too
bssrdf Nov 8, 2025
414bb8d
further reduce index swizzling computation cycles
bssrdf Nov 8, 2025
64ead3f
remove commented code
bssrdf Nov 8, 2025
9cbc099
broken for some test cases
bssrdf Nov 8, 2025
a1fb3c1
fixed a bug now split-k can choose a better split factor
bssrdf Nov 8, 2025
a3fb36f
make split-k condition check more robust
bssrdf Nov 8, 2025
6106e90
make CI happy
bssrdf Nov 9, 2025
a2db92f
make CI happy
bssrdf Nov 9, 2025
8e0e944
reduced uncoalesced global access in filter transpose
bssrdf Nov 9, 2025
5ed2c1b
reduce bank conflicts in filter transpose
bssrdf Nov 9, 2025
496c359
add loop unrolling
bssrdf Nov 9, 2025
1fdcb05
increase maximum split factor to 16; use better heuristics to choose …
bssrdf Nov 10, 2025
a660d4d
get rid of a convert unary kernel call and fuse the type cast into co…
bssrdf Nov 10, 2025
fac6f0a
add missing batch index bounds check
bssrdf Nov 11, 2025
c33e430
m16n8k16 mma works; to be cleaned up
bssrdf Nov 12, 2025
ea438d8
trying to reduce integer ops; simplify code
bssrdf Nov 12, 2025
9f498d2
only enable m16n8k16 on ampere or above
bssrdf Nov 12, 2025
0939511
change mac loop to match cutlass
bssrdf Nov 13, 2025
8bfb7ed
restore smem pointer at the end of every rs loop
Nov 13, 2025
63c53fe
WIP: move rs loop into block-k-loop following cutlass
bssrdf Nov 13, 2025
7d99222
WIP: debugging
bssrdf Nov 14, 2025
b015e4b
WIP: fixed bugs now results are correct
bssrdf Nov 14, 2025
0cb1ff4
move some registers to const memory space
bssrdf Nov 14, 2025
b4530b4
disable m16n8k16 mma for ampere for now
bssrdf Nov 14, 2025
ecbbdb6
reducing integer ops
bssrdf Nov 14, 2025
e4fbece
various small optimizations
bssrdf Nov 14, 2025
11bd980
add/fix GGML_UNUSED
Nov 14, 2025
378bb83
WIP: adding cp.async calls
bssrdf Nov 14, 2025
dbeb6ce
WIP: debugging
bssrdf Nov 15, 2025
e10b495
add the missing guard
bssrdf Nov 15, 2025
e489dd2
WIP
bssrdf Nov 15, 2025
fa7dd68
not working properly for channel numbers of 32, 48, 96 etc., ok for 6…
bssrdf Nov 15, 2025
3591e83
the special filter transpose NCHW2NHWC is broken, disable it and use …
bssrdf Nov 16, 2025
721fa41
restore split-k for small inputs
bssrdf Nov 16, 2025
bccd869
fixed a bug in the special filter transpose NCHW2NHWC; still failing…
bssrdf Nov 16, 2025
febee58
fixed another bug in the special filter transpose NCHW2NHWC
bssrdf Nov 16, 2025
f2187bb
added a few edge test cases
bssrdf Nov 16, 2025
f54cd74
due to cp.async, only support filter size <= 32
bssrdf Nov 16, 2025
775e48a
remove some repeated index computation; various code/comments clean up
bssrdf Nov 17, 2025
3e69104
minor tweak to filter transpose
bssrdf Nov 17, 2025
9bb5eb3
tuned block dimensions for filter transpose
bssrdf Nov 17, 2025
5fbdefd
use fastdiv in filter transpose
bssrdf Nov 17, 2025
5e49125
make CI happy
bssrdf Nov 17, 2025
ba754ce
remove trailing blanks
Nov 17, 2025
7344456
further reduce repeated index computations
bssrdf Nov 18, 2025
e760cd4
fix CI
bssrdf Nov 19, 2025
1,051 changes: 1,051 additions & 0 deletions ggml/src/ggml-cuda/conv2d-implicit.cu

Large diffs are not rendered by default.

347 changes: 347 additions & 0 deletions ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -0,0 +1,347 @@
#pragma once
#include "common.cuh"

typedef struct{
unsigned int n; //batch size
unsigned int c; //number of channels
unsigned int h; //height
unsigned int w; //width
unsigned int k; //number of filters
unsigned int r; //filter height
unsigned int s; //filter width
unsigned int u; //stride height
unsigned int v; //stride width
unsigned int p; //padding height
unsigned int q; //padding width
unsigned int d_h; //dilation height
unsigned int d_w; //dilation width
unsigned int Oh; //output height
unsigned int Ow; //output width
uint3 SC_fastdiv;
uint3 OW_fastdiv;
uint3 C_fastdiv;
uint3 RS_fastdiv;
uint3 S_fastdiv;
uint3 OHOW_fastdiv;
} param_t;
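// The uint3 *_fastdiv members hold precomputed constants for the fastdiv()/
// fastmodulo() helpers from common.cuh, avoiding runtime integer division in
// the kernels; the names suggest divisors of s*c, Ow, c, r*s, s and Oh*Ow,
// respectively. Oh and Ow follow the usual convolution output-size relation,
// e.g. Oh = (h + 2*p - d_h*(r - 1) - 1)/u + 1, and analogously for Ow.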


// copy a tile of the filter tensor (the B matrix of the implicit GEMM, K rows by R*S*C columns) from global to shared memory; writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
const half* src,
half* dst,
const unsigned int src_stride,
param_t param
){
#if __CUDA_ARCH__ >= GGML_CUDA_TURING

constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;

float4* dst_float4 = reinterpret_cast<float4*>(dst);

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten the 2D grid of threads into a linear index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // filter row index (r)
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column index (s)
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel index (c)

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// apply swizzle to the dst index
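// (dst_index is a float4 index, thread_row * 4 + thread_col; XOR-ing its low
// 2 column bits with bits taken from the row permutes where each float4 lands
// within its group of rows, so later shared-memory reads are spread across
// banks instead of piling onto the same ones)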
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // out of bounds: zero-fill the 8 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(src_stride);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}


// like tileMemcpySwizzleB, but gathers the input-activation (A) tile directly from the channels-last input tensor (TILE_COLS fixed at 32)
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
const half* src,
half* dst,
// const unsigned int src_stride,
const unsigned int inChannelOffset,
param_t param
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_TURING

constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;

float4* dst_float4 = reinterpret_cast<float4*>(dst);

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten the 2D grid of threads into a linear index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
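
// Implicit-GEMM index mapping used by the gather below: GEMM row gemm_i
// enumerates output pixels and is decomposed into (n, oh, ow) via Oh*Ow and Ow;
// the GEMM column (thread_col*8) is decomposed into a filter tap (r, s, c) via
// s*c and c. The element fetched is then
//   input[n, oh*u - p + r*d_h, ow*v - q + s*d_w, c]
// in the channels-last layout, with zero fill for out-of-bounds taps.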


#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // filter row index (r)
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column index (s)
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel index (c)
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(inChannelOffset);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}

template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadA(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
const unsigned int inChannelOffset,
param_t param
){
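// Same gather as tileMemcpySwizzleA, but starting at column offset block_k and
// staging the tile in registers (dst_reg) instead of writing shared memory
// directly; the swizzled store is done later via tileMemcpySwizzleStore,
// presumably so the global loads for the next tile can overlap with compute.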
#if __CUDA_ARCH__ >= GGML_CUDA_TURING

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten the 2D grid of threads into a linear index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // filter row index (r)
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column index (s)
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel index (c)
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);
GGML_UNUSED(block_k);
GGML_UNUSED(inChannelOffset);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}


template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadB(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int src_stride,
param_t param
){
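// Register-staged counterpart of tileMemcpySwizzleB: loads the filter tile at
// column offset block_k into dst_reg, with the swizzled write to shared memory
// done separately (see tileMemcpySwizzleStore).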
#if __CUDA_ARCH__ >= GGML_CUDA_TURING

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten the 2D grid of threads into a linear index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // filter row index (r)
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column index (s)
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel index (c)

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // out of bounds: zero-fill the 8 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);
GGML_UNUSED(block_k);
GGML_UNUSED(src_stride);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}


// store a tile held in registers to shared memory, applying the same swizzle
// as the copy routines above (TILE_COLS fixed at 32)
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpySwizzleStore(
const float4 (&src_reg)[ELEMENTS_PER_THREAD],
half* dst
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_TURING

constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;

// reinterpret input/output as float4
float4* dst_float4 = reinterpret_cast<float4*>(dst);

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten the 2D grid of threads into a linear index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
dst_float4[dst_index] = src_reg[i];
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src_reg);
GGML_UNUSED(dst);
NO_DEVICE_CODE;
#endif
}

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
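// cp.async (and ldmatrix-style shared-memory loads) take a 32-bit address in
// the shared state space, so convert the generic pointer with cvta.to.shared
// and narrow the result to 32 bits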
uint32_t address;
asm("{\n\t"
" .reg .u64 u64addr;\n\t"
" cvta.to.shared.u64 u64addr, %1;\n\t"
" cvt.u32.u64 %0, u64addr;\n\t"
"}"
: "=r"(address)
: "l"(pointer));
return address;
}


#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -13,6 +13,7 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-implicit.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@@ -2461,7 +2462,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_im2col_3d(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
ggml_cuda_op_conv2d_implicit(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
@@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS)
endif()
llama_build_and_test(test-gguf.cpp)
llama_build_and_test(test-backend-ops.cpp)
llama_build_and_test(test-conv2d.cpp)

llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
llama_build_and_test(test-autorelease.cpp LABEL "model")
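
For context, the dispatch change in ggml-cuda.cu above keeps the public API untouched: GGML_OP_CONV_2D nodes are simply routed to the new implicit-GEMM kernel. Below is a minimal sketch of how such a node is built, assuming the existing ggml_conv_2d_direct builder (kernel tensor first, then input; strides, paddings, dilations) and purely illustrative SD-like shapes.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // ggml stores ne0 fastest: kernel is [KW, KH, IC, OC], input is [W, H, C, N]
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  3,  3, 320, 320);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 320,   1);

    // stride 1, padding 1, dilation 1: a 3x3 conv that keeps the 64x64 spatial size
    struct ggml_tensor * out = ggml_conv_2d_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);

    // on the CUDA backend this GGML_OP_CONV_2D node is now executed by
    // ggml_cuda_op_conv2d_implicit() instead of ggml_cuda_op_conv2d()
    GGML_ASSERT(out->op == GGML_OP_CONV_2D);

    ggml_free(ctx);
    return 0;
}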