diff --git a/ggml/src/ggml-cuda/conv2d-mm.cu b/ggml/src/ggml-cuda/conv2d-mm.cu
index 7de78b4372eda..0d46bb16897e7 100644
--- a/ggml/src/ggml-cuda/conv2d-mm.cu
+++ b/ggml/src/ggml-cuda/conv2d-mm.cu
@@ -1,4 +1,5 @@
 #include "conv2d-mm.cuh"
+#include "convert.cuh"

 #include
@@ -13,6 +14,8 @@
 #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

+typedef uint32_t uint;
+
 uint32_t ceil_div(uint32_t M, uint32_t N);
 int get_sm_count();
@@ -51,29 +54,20 @@ __align__(16) struct Params {
     uint32_t nb2;
     uint32_t nb3;

-    uint32_t KWmp;
-    uint32_t KWL;
-    uint32_t KWKHmp;
-    uint32_t KWKHL;
-    uint32_t OWmp;
-    uint32_t OWL;
-    uint32_t OWOHmp;
-    uint32_t OWOHL;
+    uint3 KW_fastdiv;
+    uint3 KWKH_fastdiv;
+    uint3 OW_fastdiv;
+    uint3 OWOH_fastdiv;
 };

 __constant__ __device__ Params dp;

-// see init_fastdiv_values in ggml-vulkan.cpp
-__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
-    return (__umulhi(n, mp) + n) >> L;
-}
-
 // --> conv_2d kernel modified to function as a matmul
-template
+template
 __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
                                                  uint NPQ,
                                                  uint CRS,
-                                                 const float * knl_data,
+                                                 const T * knl_data,
                                                  const float * src_data,
                                                  float * dst_data) {
     // Each block computes a tile of the result of size BS_K*BS_NPQ
@@ -98,7 +92,8 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
     const uint T_y = threadIdx.x / NT_x;
     const uint T_x = threadIdx.x % NT_x;

-    __shared__ float Ash[BS_K * BS_CRS];
+    // __shared__ float Ash[BS_K * BS_CRS];
+    __shared__ T Ash[BS_K * BS_CRS];
     __shared__ float Bsh[BS_CRS * BS_NPQ];

     const uint Ar = threadIdx.x / BS_CRS;
@@ -135,10 +130,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
 #else
         uint32_t CRS_idx_a = idx_CRS + Ac;  //Global CRS_idx (column index of A)
         //uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH);
-        uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKHmp, dp.KWKHL);  // divide by (p.KW * p.KH); / (p.KW * p.KH);
+        uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKH_fastdiv);  // divide by (p.KW * p.KH); / (p.KW * p.KH);
         uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH;
         //uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW;
-        uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KWmp, dp.KWL);  // divide by p.KW;
+        uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KW_fastdiv);  // divide by p.KW;
         //uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused
 #endif

@@ -148,9 +143,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
         // General addressing (does not assume contiguity)
         //const uint32_t knl_idx = KW_idx_a + KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03;
         // Contiguous addressing
-        float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
+        T val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
         if (CRS_idx_a >= CRS || K_idx_a >= K) {
-            val = 0.0;
+            val = (T)0.0;
         }

 #ifdef A_TRANS
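For context on the hunks above: the file-local `fastdiv(n, mp, L)` helper is removed and every call site now passes one of the new `uint3` fields from `Params`, so the kernel presumably relies on the shared `uint3`-based `fastdiv` from ggml's common CUDA headers. A minimal sketch of what that helper is assumed to look like (the `x`/`y` packing is an assumption, not something this patch shows); remainders in the kernel are still recovered with an explicit multiply-subtract, as visible in the hunks above:

```cuda
#include <cstdint>

// Sketch only: assumed layout of the shared uint3 fastdiv helper,
// packing x = magic multiplier m' and y = shift L for a fixed divisor d.
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fd) {
    // n / d computed as (umulhi(n, m') + n) >> L (divcnst-pldi94, figure 4.1)
    return (__umulhi(n, fd.x) + n) >> fd.y;
}
```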
@@ -173,10 +168,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
         // Compute indices for N, OH, OW from NPQ_idx
         const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */
         //const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW);
-        uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOHmp, dp.OWOHL);  // divide by p.OH * p.OW;
+        uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOH_fastdiv);  // divide by p.OH * p.OW;
         uint32_t NPQ_remainder = NPQ_idx - N_idx * dp.OH * dp.OW;
         //const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW;
-        uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OWmp, dp.OWL);  // divide by p.OW;
+        uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OW_fastdiv);  // divide by p.OW;
         const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW;

 #ifdef USE_COLLECTIVES
@@ -188,10 +183,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
         // Compute indices KH, KW, Cin from CRS_idx
         uint32_t CRS_idx_b = idx_CRS + r_offset + Br;
         //uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH);
-        uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKHmp, dp.KWKHL);  // divide by (p.KW * p.KH); / (p.KW * p.KH);
+        uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKH_fastdiv);  // divide by (p.KW * p.KH); / (p.KW * p.KH);
         uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH;
         //uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW;
-        uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KWmp, dp.KWL);  // divide by p.KW;
+        uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KW_fastdiv);  // divide by p.KW;
         uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW;
 #endif

@@ -235,9 +230,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
 # else
                 uint32_t col_offset = (T_y * TS_K + T_ly);
 # endif
-                regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset];
+                regA[T_ly] = ggml_cuda_cast<float>(Ash[CRS_lidx * BS_K + col_offset]);
 #else
-                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx];
+                regA[T_ly] = ggml_cuda_cast<float>(Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]);
 #endif
             }
             for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) {
@@ -267,9 +262,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
             const uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
             const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
             //const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW);
-            const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOHmp, dp.OWOHL);  // divide by p.OH * p.OW;
+            const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOH_fastdiv);  // divide by p.OH * p.OW;
             //const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW;
-            const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OWmp, dp.OWL);  // divide by p.OW;
+            const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OW_fastdiv);  // divide by p.OW;
             const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW;
             const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3;
             if (K_idx < K && NPQ_idx_c < NPQ) {
@@ -279,22 +274,6 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
     }
 }

-// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
-// Precompute mp (m' in the paper) and L such that division
-// can be computed using a multiply (high 32b of 64b result)
-// and a shift:
-//
-// n/d = (mulhi(n, mp) + n) >> L;
-static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
-    // compute L = ceil(log2(d));
-    L = 0;
-    while (L < 32 && (uint32_t{ 1 } << L) < d) {
-        L++;
-    }
-
-    mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
-}
-
 constexpr int conv_shapes[][NUM_VARIANTS] = {
     { 128, 64, 32 },  // BS_K
     { 16,  32, 16 },  // BS_CRS
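The removed host-side `init_fastdiv_values(d, mp, L)` above spells out the math; the host code further down now calls a `uint3`-returning `init_fastdiv_values(d)` from ggml's shared CUDA code instead. A sketch of what that replacement presumably computes, reusing the removed code's constants; the field order matches the assumption made in the device-side sketch earlier (the `_sketch` suffix marks it as hypothetical):

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Sketch only: same magic constants as the removed helper, packed into a uint3.
static uint3 init_fastdiv_values_sketch(uint32_t d) {
    // L = ceil(log2(d))
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    // m' such that n / d == (umulhi(n, m') + n) >> L
    const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    return make_uint3(mp, L, d);  // assumed packing: x = m', y = L, z = d
}
```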
@@ -340,19 +319,26 @@ void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx,
     uint32_t NB_K   = CEIL_DIV(p.Cout, BS_K);
     uint32_t NB_NPQ = CEIL_DIV(NPQ, BS_NPQ);

+    cudaStream_t stream = ctx.stream();
     cudaMemcpyToSymbol(dp, &p, sizeof(Params));
+    // cudaMemcpyToSymbolAsync(dp, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, stream);

     // Kernel arguments
-    float * src0_data = (float *) src0->data;
     float * src1_data = (float *) src1->data;
     float * dst_data  = (float *) dst->data;

     dim3 gridDim(NB_K, NB_NPQ);
     dim3 blockDim(WG_SIZE);

-    cudaStream_t stream = ctx.stream();
-
-    mm
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_data = (half *) src0->data;
+        mm
+            <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    } else {
+        float * src0_data = (float *) src0->data;
+        mm
             <<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
+    }
+
 }

 void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -372,13 +358,13 @@
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src1 = dst->src[1];

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

     GGML_TENSOR_BINARY_OP_LOCALS

-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(half));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));
@@ -413,10 +399,10 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     p.nb2 = static_cast<uint32_t>(nb2 / nb0);
     p.nb3 = static_cast<uint32_t>(nb3 / nb0);

-    init_fastdiv_values(p.KW, p.KWmp, p.KWL);
-    init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL);
-    init_fastdiv_values(p.OW, p.OWmp, p.OWL);
-    init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL);
+    p.KW_fastdiv   = init_fastdiv_values(p.KW);
+    p.KWKH_fastdiv = init_fastdiv_values(p.KW * p.KH);
+    p.OW_fastdiv   = init_fastdiv_values(p.OW);
+    p.OWOH_fastdiv = init_fastdiv_values(p.OW * p.OH);

     GGML_ASSERT(ne03 == ne2);
     GGML_ASSERT(ne02 == ne12);
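The F16 path added above only changes how the kernel's A-operand is stored (`Ash` holds `T`, and loaded values are widened with `ggml_cuda_cast<float>`); all accumulation stays in FP32, and the host side picks the instantiation from `src0->type`. A standalone sketch of that dispatch-plus-FP32-accumulation pattern, with hypothetical names and without the PR's actual template parameter list:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Sketch only: weights stored as T (half or float), accumulation in float.
template <typename T>
__global__ void dot_fp32_acc(const T * w, const float * x, float * y, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += (float) w[i] * x[i];  // widen T to float before the multiply-add
    }
    *y = acc;
}

// Runtime dispatch on the weight type, mirroring the GGML_TYPE_F16 branch above.
static void launch_dot(bool f16_weights, const void * w, const float * x, float * y, int n, cudaStream_t stream) {
    if (f16_weights) {
        dot_fp32_acc<half><<<1, 1, 0, stream>>>((const half *) w, x, y, n);
    } else {
        dot_fp32_acc<float><<<1, 1, 0, stream>>>((const float *) w, x, y, n);
    }
}
```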
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a9fe51875af46..bc944170600a2 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2463,7 +2463,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_CONV_2D:
             if (!getenv("GGML_CUDA_USE_LEGACY_CONV") &&
-                (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
+                (dst->src[1]->type == GGML_TYPE_F32 &&
                  dst->type == GGML_TYPE_F32)) {
                 ggml_cuda_op_conv2d_mm(ctx, dst);
             } else {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 893c2af3137a8..12718ce18806c 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -39,6 +39,7 @@
 #include
 #include
 #include
+#include <map>

 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     size_t nels = ggml_nelements(tensor);
@@ -6615,14 +6616,66 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         { 16, 3, 512, 128, 8 },
     };

-    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        for (auto act_case : cases) {
-            // Direct CONV_2D
-            test_cases.emplace_back(new test_conv_2d(
-                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
-                kernel_type, 1, 1, 0, 0, 1, 1, false));
-        }
+    // for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+    //     for (auto act_case : cases) {
+    //         // Direct CONV_2D
+    //         test_cases.emplace_back(new test_conv_2d(
+    //             { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+    //             { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+    //             kernel_type, 1, 1, 0, 0, 1, 1, false));
+    //     }
+    // }
+
+    // Stable-diffusion layers
+    std::map<std::string, uint32_t> idx_sd{
+        { "iw",   0 },
+        { "ih",   1 },
+        { "kw",   2 },
+        { "kh",   3 },
+        { "Cout", 4 },
+        { "Cin",  5 },
+        { "B",    6 },
+    };
+
+    // Input image size
+    uint32_t w = 768;
+    uint32_t h = 1024;
+
+    // Number of filters (base)
+    uint32_t Cout_b = 128;
+    uint32_t Cin_b  = 128;
+
+    std::vector<std::vector<uint32_t>> cases_sd = {
+        { w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 },  // x10 (called 10 times)
+        { w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 },  // x7
+        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 },  // x5
+        { w,     h,     3, 3, Cout_b,     Cin_b,     1 },  // x5
+        { w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 },  // x4
+        { w / 8, h / 8, 1, 1, 4,          4,         1 },
+        { w / 8, h / 8, 3, 3, Cout_b * 4, 4,         1 },
+
+        { w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 },
+        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 },
+        { w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 },
+
+        { w,     h,     3, 3, Cout_b,     Cin_b * 2, 1 },
+        { w,     h,     1, 1, Cout_b,     Cin_b * 2, 1 },
+        { w,     h,     3, 3, Cout_b * 2, Cin_b * 2, 1 },
+
+        { w,     h,     3, 3, 3,          Cin_b,     1 },
+    };
+
+    for (auto act_case : cases_sd) {
+        GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
+        GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);
+
+        uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
+        uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
+
+        test_cases.emplace_back(new test_conv_2d(
+            { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
+            { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
+            GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
     }

     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
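The `p0`/`p1` choice above pads 3×3 kernels by 1 and 1×1 kernels by 0, so every stable-diffusion case keeps its spatial size. A quick check against the standard convolution output-size formula (hypothetical helper, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// out = (in + 2*p - d*(k - 1) - 1) / s + 1, the usual conv output-size formula.
static int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}

int main() {
    std::printf("%lld\n", (long long) conv_out_size(768,  3, 1, 1, 1));  // 3x3, pad 1 -> 768
    std::printf("%lld\n", (long long) conv_out_size(1024, 1, 1, 0, 1));  // 1x1, pad 0 -> 1024
    return 0;
}
```

With stride 1 and dilation 1 this reduces to `in + 2*p - k + 1`, which equals `in` exactly for (k = 3, p = 1) and (k = 1, p = 0).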