diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index d953cadd62dcf..4cc42e1674ccc 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -571,6 +571,10 @@ int main(int argc, char ** argv) { model_ttc = llama_init_ttc.model.get(); ctx_ttc = llama_init_ttc.context.get(); + if (model_ttc == nullptr || ctx_ttc == nullptr) { + return ENOENT; + } + const llama_vocab * vocab = llama_model_get_vocab(model_ttc); // TODO: refactor in a common struct @@ -586,6 +590,10 @@ int main(int argc, char ** argv) { model_cts = llama_init_cts.model.get(); ctx_cts = llama_init_cts.context.get(); + if (model_cts == nullptr || ctx_cts == nullptr) { + return ENOENT; + } + std::vector smpl(n_parallel); for (int i = 0; i < n_parallel; ++i) { params.sampling.no_perf = (i != 0); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c1c7498694beb..1e4c2422756ac 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -76,7 +76,11 @@ if (GGML_CCACHE) set(GGML_CCACHE_VARIANT sccache) endif() # TODO: should not be set globally - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}") + if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl") + else () + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}") + endif () set(ENV{CCACHE_SLOPPINESS} time_macros) message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.") else() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 75dc96b478655..2dbe835586d4c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont( const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - // parallelize by elements - const int ne = ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = MIN(ie0 + dr, ne); + // parallelize by blocks + const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type); + const int dr = (nk + nth - 1) / nth; + const int k0 = dr * ith; + const int k1 = MIN(k0 + dr, nk); - if (ie0 < ie1) { + if (k0 < k1) { memcpy( - ((char *) dst->data + ie0*nb0), - ((char *) src0->data + ie0*nb0), - (ie1 - ie0) * nb0); + ((char *) dst->data + k0*nb0), + ((char *) src0->data + k0*nb0), + (k1 - k0) * nb0); } } @@ -4055,7 +4055,6 @@ static void ggml_compute_forward_dup_f32( static void ggml_compute_forward_dup_bytes( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); @@ -4069,10 +4068,10 @@ static void ggml_compute_forward_dup_bytes( } const size_t type_size = ggml_type_size(src0->type); + const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - // parallelize by rows const int nr = ne01; // number of rows per thread @@ -4082,10 +4081,10 @@ static void ggml_compute_forward_dup_bytes( const int ir1 = MIN(ir0 + dr, nr); if (src0->type == dst->type && - ne00 == ne0 && + ggml_are_same_shape(src0, dst) && nb00 == type_size && nb0 == type_size) { // copy by rows - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -4140,17 +4139,20 @@ static void 
ggml_compute_forward_dup_bytes( } // dst counters - - int64_t i10 = 0; + int64_t k10 = 0; int64_t i11 = 0; int64_t i12 = 0; int64_t i13 = 0; + // number of blocks in a row + const int64_t nk00 = ne00 / ggml_blck_size(src0->type); + const int64_t nk0 = ne0 / ggml_blck_size(dst->type); + for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; + k10 += nk00 * ir0; + while (k10 >= nk0) { + k10 -= nk0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -4162,14 +4164,14 @@ static void ggml_compute_forward_dup_bytes( } } for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + for (int64_t k00 = 0; k00 < nk00; k00++) { + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); memcpy(dst_ptr, src0_ptr, type_size); - if (++i10 == ne0) { - i10 = 0; + if (++k10 == nk0) { + k10 = 0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -4182,9 +4184,9 @@ static void ggml_compute_forward_dup_bytes( } } } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; + k10 += nk00 * (ne01 - ir1); + while (k10 >= nk0) { + k10 -= nk0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -14308,7 +14310,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } // extra_buffer op? - if (ggml_cpu_extra_compute_forward(params, tensor)) return; + if (ggml_cpu_extra_compute_forward(params, tensor)) { + return; + } switch (tensor->op) { case GGML_OP_DUP: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e78205e5d53af..6b5cd32a4a541 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -41,14 +41,17 @@ #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. 
for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL 600
-#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA 700
-#define GGML_CUDA_CC_TURING 750
-#define GGML_CUDA_CC_AMPERE 800
-#define GGML_CUDA_CC_ADA_LOVELACE 890
-#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
-
+#define GGML_CUDA_CC_PASCAL 600
+#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA 700
+#define GGML_CUDA_CC_TURING 750
+#define GGML_CUDA_CC_AMPERE 800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
+#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
+
+// AMD
 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
@@ -70,8 +73,17 @@
 #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
-#define GGML_CUDA_CC_QY1 210
-#define GGML_CUDA_CC_QY2 220
+// Moore Threads
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
+#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
 
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
@@ -209,21 +221,21 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
 }
 
 static bool fast_fp16_available(const int cc) {
-    return fp16_available(cc) && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
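
With these definitions, a single integer `cc` encodes both the vendor and the architecture: `ggml_cuda_init()` (see the ggml-cuda.cu hunk below) packs a MUSA device's major/minor capability into the dedicated Moore Threads range instead of the old `100*major + 10*minor` value. A minimal standalone sketch of that encoding and the resulting range checks, with the offsets copied from this patch; the `is_amd` helper here is an assumption that mirrors how the AMD offset sits above the Moore Threads range and is not the literal GGML_CUDA_CC_IS_AMD macro:

// Sketch only, not part of the patch: how a MUSA device's major/minor pair is
// packed into the ggml "cc" integer and how the vendor range checks classify it.
#include <cstdio>

constexpr int CC_OFFSET_MTHREADS = 0x0100000; // GGML_CUDA_CC_OFFSET_MTHREADS
constexpr int CC_OFFSET_AMD      = 0x1000000; // GGML_CUDA_CC_OFFSET_AMD

constexpr bool is_nvidia  (int cc) { return cc < CC_OFFSET_MTHREADS; }                        // GGML_CUDA_CC_IS_NVIDIA
constexpr bool is_mthreads(int cc) { return cc >= CC_OFFSET_MTHREADS && cc < CC_OFFSET_AMD; } // GGML_CUDA_CC_IS_MTHREADS
constexpr bool is_amd     (int cc) { return cc >= CC_OFFSET_AMD; }                            // assumed AMD check

int main() {
    const int major = 2, minor = 1; // e.g. an MTT S80-class (QY1) device
    const int cc = CC_OFFSET_MTHREADS + major * 0x100 + minor * 0x10; // same arithmetic as ggml_cuda_init()
    std::printf("cc = 0x%07x  nvidia=%d  mthreads=%d  amd=%d\n",
                cc, is_nvidia(cc), is_mthreads(cc), is_amd(cc));
    // cc == CC_OFFSET_MTHREADS + 0x210, i.e. GGML_CUDA_CC_QY1
    return 0;
}

This keeps the existing offset-based comparisons meaningful while letting NVIDIA, Moore Threads, and AMD devices be distinguished by range alone.
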
@@ -231,20 +243,20 @@ static bool fp16_mma_available(const int cc) { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) return false; #else - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || - GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || - GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. static bool new_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } static bool cp_async_available(const int cc) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 973541893ec21..8edc12649aa63 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - if (cc >= GGML_CUDA_CC_OFFSET_AMD) { + if (GGML_CUDA_CC_IS_AMD(cc)) { #if defined(GGML_HIP_ROCWMMA_FATTN) if (fp16_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b783310ef7ba7..10d461b773a6a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -264,9 +264,9 @@ static ggml_cuda_device_info ggml_cuda_init() { #elif defined(GGML_USE_MUSA) // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs. info.devices[id].warp_size = 32; - // TODO: refine the .cc to reflect MUSA's actual CC capabilities info.devices[id].smpbo = prop.sharedMemPerBlockOptin; - info.devices[id].cc = 100*prop.major + 10*prop.minor; + info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #else @@ -1188,11 +1188,11 @@ static void ggml_cuda_op_mul_mat_cublas( // ldc == nrows of the matrix that cuBLAS writes into int64_t ldc = id == ctx.device ? 
ne0 : row_diff; - const int compute_capability = ggml_cuda_info().devices[id].cc; + const int cc = ggml_cuda_info().devices[id].cc; const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; - if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { + if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(compute_capability)) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 10f2ebb1cb43d..510c1e9b2aa9e 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q( // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && - cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; + GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11; const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; switch (src0->type) { @@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return true; #endif //GGML_CUDA_FORCE_MMQ - if (cc < GGML_CUDA_CC_OFFSET_AMD) { + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f2aca1f2014e1..4ea8b8d4b1290 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -90,7 +90,7 @@ struct tile_x_sizes { static int get_mmq_x_max_host(const int cc) { return new_mma_available(cc) ? 128 : - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? + ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; #else @@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() { } static int get_mmq_y_host(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : - (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64); + return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + ((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 
128 : 64); } static constexpr __device__ int get_mmq_y_device() { @@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shmem_limit_raised[id]) { CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); shmem_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) const int nty = (args.ne01 + mmq_y - 1) / mmq_y; const int ntx = (args.ne11 + mmq_x - 1) / mmq_x; @@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; - const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD; + const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc); int mmq_x_best = 0; int nparts_best = INT_MAX; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d450fe9a2f2f6..37fa8eec599a3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -149,6 +149,7 @@ class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; enum vk_device_architecture { OTHER, @@ -231,6 +232,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_add; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -277,7 +279,7 @@ struct vk_device_struct { vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; @@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_add && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, 
device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2281,13 +2289,21 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 
1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); @@ -2471,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceDriverProperties driver_props; vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; props2.pNext = &props3; props3.pNext = &subgroup_props; subgroup_props.pNext = &driver_props; - driver_props.pNext = &vk12_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; @@ -2541,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = 
vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -4627,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t d_sz = sizeof(float) * d_ne; + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -4653,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index c813f14044eca..9c76437d9b0b9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,5 +1,10 @@ #version 450 +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 + #include "types.comp" #include "generic_unary_head.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 8835c442ecfd8..2a162a2c81543 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint 
a_offset) { - const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]); - const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; return vec4(v0.x, v0.y, v1.x, v1.y); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index 9718a05e5adb2..8d01536fa69c0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const float db = d * (0.5 + scale) * 0.25; const uint qh = data_a[ibi].qh[ib32]; - const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]); - const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]); + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; [[unroll]] for (uint l = 0; l < 2; ++l) { const uint8_t sign = sign16[l]; const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index af48f32902fe2..f021e40476199 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, sum[j] = 0.0; } [[unroll]] for (uint l = 0; l < 4; ++l) { - const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]); + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 const uint sign = data_a[ibi].signs[4 * ib32 + l]; const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 1cc4996d393a2..48376637fb3e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + layout (push_constant) uniform parameter { uint ncols_x; @@ -37,25 +40,66 @@ void main() { const uint idst = channel*nrows_dst + row_dst; - tmp[tid] = 0.0f; + FLOAT_TYPE temp = 0.0f; - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; - if (col_x >= p.ncols_x) { - break; - } + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < 
unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; - const uint row_y = col_x; + const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; + } } + tmp[tid] = temp; + // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp index 9b443807d8781..7aa070eebdf72 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -2,16 +2,25 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif -#define BLOCK_SIZE 32 #define FLOAT_TYPE float -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] +layout(constant_id = 1) const uint gqa_ratio = 1; + layout (push_constant) uniform parameter { uint ncols_x; @@ -22,52 +31,124 @@ layout (push_constant) uniform parameter uint d_offset; } p; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. 
+ if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - tmp[tid] = FLOAT_TYPE(0.0f); + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); + } + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - if (col_x >= p.ncols_x) { - break; - } + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { - // x is transposed and permuted - const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; - const uint row_y = col_x; + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); - // y is not transposed but permuted - const uint iy = channel*nrows_y + row_y; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); - } + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } - // dst is not transposed and not permuted - const uint idst = channel*nrows_dst + row_dst; + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } } barrier(); } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif if (tid == 0) { - dst[idst] = tmp[0]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 0d03411f24ca4..5a0054bac336c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -336,8 +336,8 @@ void main() { const uint iqs = idx & 0x07; const float d = 
float(data_a_packed16[ib].d); - const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]); - const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); @@ -544,7 +544,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -564,7 +564,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -586,7 +586,7 @@ void main() { const float db = d * 0.25 * (0.5 + scale); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -611,7 +611,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -631,7 +631,7 @@ void main() { const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index eb2ad63ff6bf0..1edb8267f1ebe 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -426,8 +426,9 @@ void process_shaders() { } } - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, 
{"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -445,6 +446,7 @@ void process_shaders() { for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cc48913d9789d..13cca7ab009bf 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1113,6 +1113,7 @@ class MODEL_TENSOR(IntEnum): ], MODEL_ARCH.GEMMA3: [ MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9debb56cc80d5..8664f8963cc18 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -778,6 +778,7 @@ static const std::map> LLM_TENSOR_N { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cd7e0a0c4dbf8..0ae754154b069 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -271,17 +271,30 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } } - // add extra buffer types - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(cpu_dev, *extra_bufts); - ++extra_bufts; + bool has_gpu_device = false; + for (auto * dev : devices) { + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { + has_gpu_device = true; + break; + } + } + + // add extra buffer types, only if no GPU device is present + // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 + if (!has_gpu_device) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = 
ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } } + } else { + LLAMA_LOG_WARN("%s: disabling extra buffer types (i.e. repacking) since a GPU device is available\n", __func__); } // add a host buffer type @@ -2329,7 +2342,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2558,7 +2571,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -3215,16 +3233,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -3335,12 +3353,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_w2 = 
create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); @@ -3370,7 +3388,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); const int time_mix_extra_dim = hparams.time_mix_extra_dim; @@ -3396,7 +3414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, 
attn_hidden_size}, 0); @@ -3405,9 +3423,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); // optional bias tensors - layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); @@ -3528,8 +3546,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); } - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); try { layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); @@ -3546,8 +3564,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d48cd21723155..ebc32d7919fe4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1463,11 +1463,13 @@ struct test_cpy : public test_case { const ggml_type type_src; const ggml_type type_dst; const std::array ne; - const std::array permute; + const std::array permute_src; + const std::array permute_dst; bool 
_src_use_permute; + bool _dst_use_permute; std::string vars() override { - return VARS_TO_STR4(type_src, type_dst, ne, permute); + return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); } double max_nmse_err() override { @@ -1480,9 +1482,11 @@ struct test_cpy : public test_case { test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}, - std::array permute = {0, 0, 0, 0}) - : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute), - _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} + std::array permute_src = {0, 0, 0, 0}, + std::array permute_dst = {0, 0, 0, 0}) + : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst), + _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0), + _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -1490,13 +1494,18 @@ struct test_cpy : public test_case { ggml_set_name(src, "src"); if (_src_use_permute) { - src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]); + src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]); ggml_set_name(src, "src_permuted"); } - ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne); + ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne); ggml_set_name(dst, "dst"); + if (_dst_use_permute) { + dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]); + ggml_set_name(dst, "dst_permuted"); + } + ggml_tensor * out = ggml_cpy(ctx, src, dst); ggml_set_name(out, "out"); @@ -1964,9 +1973,10 @@ struct test_mul_mat : public test_case { const std::array bs; // dims 3 and 4 const std::array nr; // repeat in dims 3 and 4 const std::array per; // permutation of dimensions + const bool v; // whether a is a non-contiguous view std::string vars() override { - return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per); + return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v); } double max_nmse_err() override { @@ -1986,8 +1996,9 @@ struct test_mul_mat : public test_case { int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, std::array nr = {2, 2}, - std::array per = {0, 1, 2, 3}) - : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {} + std::array per = {0, 1, 2, 3}, + bool v = false) + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) @@ -1997,6 +2008,7 @@ struct test_mul_mat : public test_case { const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3); if (npermuted > 0) { GGML_ASSERT(npermuted == 2); + GGML_ASSERT(!v); // not handled GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0); GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0); @@ -2020,7 +2032,13 @@ struct test_mul_mat : public test_case { ggml_set_name(a, "a_permuted"); ggml_set_name(b, "b_permuted"); } else { - a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + + if (v) { + a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]); + a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0); + } else { + a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + } b = 
ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); if (!ggml_is_quantized(type_a)) { if (bs[1] == 1 && nr[1] == 1) { @@ -3995,14 +4013,25 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim)); } - for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { + // same-type copy + for (ggml_type type : all_types) { + const auto nk = ggml_blck_size(type); + + for (int k = 1; k < 4; ++k) { + test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4})); + test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3})); + } + } + + for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) { for (ggml_type type_dst : all_types) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows } } - for (ggml_type type_dst : {GGML_TYPE_F32}) { - for (ggml_type type_src : all_types) { + for (ggml_type type_src : all_types) { + for (ggml_type type_dst : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows } @@ -4176,6 +4205,17 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1})); + for (auto bs : {1,2,4,8}) { + for (auto nr : {1,4}) { + for (uint32_t m = 0; m < 2; ++m) { + for (uint32_t k = 0; k < 2; ++k) { + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true)); + } + } + } + } + // sycl backend will limit task global_range < MAX_INT // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion) // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -4444,6 +4484,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) {
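
For reference, the grouped-query-attention batching added to the Vulkan p021 mat-vec path earlier in this patch reduces to a small amount of host-side arithmetic: the ratio of Q channels to K/V channels selects one of the `p021_max_gqa_ratio` pipeline variants and shrinks the Z dimension of the dispatch. Below is a minimal sketch of that selection logic, assuming the same clamping rules as `ggml_vk_mul_mat_vec_p021_f16_f32()`; the helper name `pick_gqa_ratio` is illustrative and not from the patch:

// Sketch only, not part of the patch: gqa_ratio selection and workgroup count
// for the Vulkan p021 mat-vec path.
#include <cstdint>
#include <cstdio>

static uint32_t pick_gqa_ratio(uint32_t ne12, uint32_t ne02, uint32_t max_ratio = 8) {
    uint32_t gqa_ratio = ne12 / ne02;
    // Fall back to 1 unless ne12 is an exact multiple of ne02 within the supported range,
    // mirroring the check in ggml_vk_mul_mat_vec_p021_f16_f32().
    if (gqa_ratio > max_ratio || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
        gqa_ratio = 1;
    }
    return gqa_ratio;
}

int main() {
    const uint32_t ne02 = 8, ne12 = 32; // e.g. 8 K/V heads, 32 Q heads
    const uint32_t gqa_ratio = pick_gqa_ratio(ne12, ne02);
    uint32_t workgroups_z = ne12;
    if (gqa_ratio > 1) {
        workgroups_z /= gqa_ratio; // fewer workgroups, each handles gqa_ratio rows
    }
    std::printf("gqa_ratio=%u workgroups_z=%u\n", gqa_ratio, workgroups_z); // prints 4 and 8
    return 0;
}
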