From 36955c35b71581c8f72c659755409a7a456b106c Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 02:13:16 -0500 Subject: [PATCH 1/8] initial commit for branch glm45v --- convert_hf_to_gguf.py | 29 +++++++++++++++++++++++++++++ src/llama-arch.h | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8c5132193e0e0..36278866da005 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9219,6 +9219,35 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Glm4vMoeForConditionalGeneration") +class GLM4V_MoE(MmprojModel): + # + # the HF model's type is `glm4v_moe`. internally, it consists of two models: + # - `glm4v_moe_text` + # + main text model + # + tensor names start with "model.language_model." + # + "2D-RoPE" (aKa Roformer) w/ embeddings dynamically adapted via bicubic interpolation + # - `glm4v_moe` + # + vision adapter (ViT) + # + tensor names start with "model.visual." + # + "3D-RoPE" (without the interpolation mentioned above) + # + # other notable quirks include: + # - has MTP layer (need to keep these tensors - same as GLM-4.5-Air) + # - RoPE theta value (θ): use 10k rather than 100k for GLM-4.5-Air + # - the model's vision supports video input, but this is not implemented here + # + # for more info, refer to: + # - reference impl : https://github.com/huggingface/transformers/tree/main/src/transformers/models/glm4v_moe + # - HF model card : https://huggingface.co/zai-org/GLM-4.5V + # - arXiv paper (model) : https://arxiv.org/abs/2507.01006 + # - arXiv paper (orig. ViT) : https://arxiv.org/abs/2411.14402 + # + # TODO: the model's tokenizer has video-related special tokens - deal with these (??) + # + pass + + ###### CONVERSION LOGIC ###### diff --git a/src/llama-arch.h b/src/llama-arch.h index c3ae71655b17b..45670a189d0a5 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM4V_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -122,7 +123,6 @@ enum llm_kv { LLM_KV_GENERAL_LICENSE, LLM_KV_GENERAL_SOURCE_URL, LLM_KV_GENERAL_SOURCE_HF_REPO, - LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, From e8831e0cd334fdebc52a0c259d9ea7e23061ccb3 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 13:14:39 -0500 Subject: [PATCH 2/8] latest from upstream (#4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cuda : remove legacy copy-op pointer indirection code (#16485) * remove legacy copy-op pointer indirection code * further removal of copy-op indirection code * renamed check_node_graph_compatibility_and_refresh_copy_ops function * CUDA: add fp kernel for larger batch size MoE (#16512) * CUDA: kernel for larger batch sizes for MoE * WIP * WIP * WIP * WIP * WIP * WIP * fixup * tests * Move mmq_ids_helper to mmid * cleanup * Remove redundant checks * CUDA: use fastdiv + ggml_cuda_mad for mmvf (#16557) * CUDA: use fastdiv + ggml_cuda_mad for mmvf * use bf16 directly + fix formatting * Add exception for HIP code * CUDA: enable FA for FP32 KV cache (#16546) * vulkan: Improve build time for MSVC (#16545) Enable CMP0147 so custom build steps (invoking vulkan-shader-gen) are run in parallel. Enable /MP so source files are compiled in parallel. * vulkan: Support FA with K/V in F32 (#16543) * CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (#16577) * vulkan: Add ACC_TYPE_VEC2 implementation (#16203) Signed-off-by: Stefan Savic Co-authored-by: Stefan Savic * metal : avoid using Metal's gpuAddress property (#16576) * metal : avoid using Metal's gpuAddress property * metal : fix rope kernels buffer check --------- Signed-off-by: Stefan Savic Co-authored-by: Anav Prasad Co-authored-by: Aman Gupta Co-authored-by: Johannes Gäßler Co-authored-by: Jeff Bolz Co-authored-by: SavicStefan <50296686+SavicStefan@users.noreply.github.com> Co-authored-by: Stefan Savic Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/common.cuh | 7 - ggml/src/ggml-cuda/cpy.cu | 218 +++-------- ggml/src/ggml-cuda/cpy.cuh | 6 +- ggml/src/ggml-cuda/fattn-vec.cuh | 9 +- ggml/src/ggml-cuda/fattn.cu | 19 +- ggml/src/ggml-cuda/ggml-cuda.cu | 35 +- ggml/src/ggml-cuda/mmf.cu | 46 ++- ggml/src/ggml-cuda/mmf.cuh | 344 ++++++++++++++++-- ggml/src/ggml-cuda/mmid.cu | 164 +++++++++ ggml/src/ggml-cuda/mmid.cuh | 5 + ggml/src/ggml-cuda/mmq.cu | 169 +-------- ggml/src/ggml-cuda/mmvf.cu | 72 ++-- ggml/src/ggml-metal/ggml-metal-device.m | 24 +- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 8 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- ggml/src/ggml-vulkan/CMakeLists.txt | 9 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 +- .../vulkan-shaders/dequant_funcs_cm2.glsl | 14 + .../vulkan-shaders/flash_attn_base.glsl | 20 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 50 ++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 +- tests/test-backend-ops.cpp | 11 +- 24 files changed, 761 insertions(+), 496 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmid.cu create mode 100644 ggml/src/ggml-cuda/mmid.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e0abde5427c83..41ff89c4d6922 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -944,13 +944,6 @@ struct ggml_cuda_graph { bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; std::vector ggml_graph_properties; - bool use_cpy_indirection = false; - std::vector cpy_dest_ptrs; - char ** dest_ptrs_d; - int dest_ptrs_size = 0; - // Index to allow each cpy kernel to be aware of it's position within the graph - // relative to other cpy nodes. - int graph_cpynode_index = -1; #endif }; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 746f43966b84c..12d5bf7763c38 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,18 +8,16 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); template -static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int } template -static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int cpy_blck(cx + x_offset, cdst + dst_offset); } -// Copy destination pointers to GPU to be available when pointer indirection is in use - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers - CUDA_CHECK(cudaStreamSynchronize(stream)); - if (cuda_graph->dest_ptrs_d != nullptr) { - CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); - } - CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); - cuda_graph->dest_ptrs_size = host_dest_ptrs_size; - } - // copy destination pointers to GPU - CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); - cuda_graph->graph_cpynode_index = 0; // reset index -#else - GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream); -#endif -} - template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( @@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( @@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( @@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( @@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - char ** dest_ptrs_d = nullptr; - int graph_cpynode_index = -1; -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; - graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) @@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - if (src0->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } else { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); - } + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif - } void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - bool disable_indirection = true; - ggml_cuda_cpy(ctx, src0, dst, disable_indirection); -} - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { - if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - // Prioritize CUDA graph compatibility over direct memory copy optimization. - // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. - if (src0->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - return nullptr; - } - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_1>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_1>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - } + ggml_cuda_cpy(ctx, src0, dst); } diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 0bd3c0c6f8c27..a7a87d8fcfb7e 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,10 +2,6 @@ #define CUDA_CPY_BLOCK_SIZE 64 -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 89ab0f1638bf7..e1838fddedc6d 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nwarps = nthreads / WARP_SIZE; fattn_kernel_t fattn_kernel = flash_attn_ext_vec; - constexpr bool need_f16_K = false; - constexpr bool need_f16_V = false; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; constexpr size_t nbytes_shared = 0; launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } @@ -526,11 +526,6 @@ template void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index fe970adaecef3..7dee032c29137 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg } } -#define FATTN_VEC_CASE(D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ - return; \ - } \ +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ + return; \ + } \ + } \ #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ FATTN_VEC_CASE( 64, type_K, type_V) \ @@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS switch (K->type) { + case GGML_TYPE_F32: case GGML_TYPE_F16: break; case GGML_TYPE_Q4_1: @@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If Turing tensor cores available, use them: if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { return BEST_FATTN_KERNEL_VEC; } @@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (Q->ne[1] == 1) { if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_VEC; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 856e9de2e1115..da312992c8039 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, +static bool check_node_graph_compatibility(ggml_cgraph * cgraph, bool use_cuda_graph) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; @@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_CPY) { - - // Store the pointers which are updated for each token, such that these can be sent - // to the device and accessed using indirection from CUDA graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); - - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (!ptr) { - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); -#endif - } - } - if (!use_cuda_graph) { break; } } - if (use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = true; - // copy pointers to GPU so they can be accessed via indirection within CUDA graph - ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); - } - return use_cuda_graph; } @@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) { return false; } @@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW ) { return false; @@ -2901,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } //if rms norm is the B operand, then we don't handle broadcast - if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } @@ -3120,7 +3095,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. if (use_cuda_graph && cuda_graph_update_required) { @@ -3147,10 +3122,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - if (!use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = false; - } - #else bool use_cuda_graph = false; bool cuda_graph_update_required = false; diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 599e085ee91b7..9e2aaf52d6cce 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,5 +1,7 @@ #include "ggml.h" #include "mmf.cuh" +#include "mmid.cuh" + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mmf_ids_data ids_info{}; + mmf_ids_data * ids_info_ptr = nullptr; + ggml_cuda_pool_alloc ids_src_compact_dev; + ggml_cuda_pool_alloc ids_dst_compact_dev; + ggml_cuda_pool_alloc expert_bounds_dev; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_dst = ids ? ne1 : ne2; @@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr nchannels_y = ids->ne[0]; } + if (ids && ncols_dst > 16) { + const int64_t n_expert_used = ids->ne[0]; + const int64_t n_experts = ne02; + const int64_t n_tokens = ne12; + const int64_t ne_get_rows = n_tokens * n_expert_used; + + ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows); + ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows); + expert_bounds_dev.alloc(ctx.pool(), n_experts + 1); + + const int si1 = static_cast(ids_s1); + const int sis1 = static_cast(src1->nb[2] / src1->nb[1]); + + GGML_ASSERT(sis1 > 0); + + ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), + static_cast(n_experts), static_cast(n_tokens), static_cast(n_expert_used), static_cast(ne11), si1, sis1, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); + + ids_info.ids_src_compact = ids_src_compact_dev.get(); + ids_info.ids_dst_compact = ids_dst_compact_dev.get(); + ids_info.expert_bounds_dev = expert_bounds_dev.get(); + ids_info.n_experts = static_cast(n_experts); + ids_info.sis1 = sis1; + ids_info_ptr = &ids_info; + } + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; @@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; @@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; @@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } if (mul_mat_id) { - if (type == GGML_TYPE_F32 && src1_ncols > 32) { + if (src0_ne[1] <= 1024 && src1_ncols > 512) { return false; - } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) { + } else if(src0_ne[1] > 1024 && src1_ncols > 128) { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index a6c3adfcf1704..49d5295be0ea0 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,14 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +struct mmf_ids_data { + const int32_t * ids_src_compact = nullptr; + const int32_t * ids_dst_compact = nullptr; + const int32_t * expert_bounds_dev = nullptr; + int n_experts = 0; + int sis1 = 0; +}; + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); @@ -224,6 +232,250 @@ static __global__ void mul_mat_f( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } + +//This kernel is for larger batch sizes of mul_mat_id +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = blockIdx.y; + const int expert_start = expert_bounds[expert_idx]; + const int expert_end = expert_bounds[expert_idx + 1]; + const int ncols_expert = expert_end - expert_start; + + const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block; + const int tile_idx = blockIdx.z; + if (tile_idx >= tiles_for_expert) { + return; + } + + const int col_base = tile_idx * cols_per_block; + + GGML_UNUSED(channel_ratio); + + const int channel_x = expert_idx; + const int sample_dst = 0; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row; + y += int64_t(sample_y) *stride_sample_y; + dst += int64_t(sample_dst)*stride_sample_dst; + + const int32_t * ids_src_expert = ids_src_compact + expert_start; + const int32_t * ids_dst_expert = ids_dst_compact + expert_start; + + extern __shared__ char data_mmv[]; + char * compute_base = data_mmv; + + //const float2 * y2 = (const float2 *) y; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + + if constexpr (std::is_same_v) { + float vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float val = 0.0f; + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + val = y[channel*stride_channel_y + token*stride_col_y + col]; + } + } + vals[j0] = val; + } + }; + + gather_tile(0, vals_buf[0]); + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0]; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + float2 vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float2 *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + } + } + vals[j0] = tmp; + } + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); + } + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const float2 tmp = vals_buf[curr_buf][j0]; + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + const int global_j = col_base + j; + if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) { + const int dst_entry = ids_dst_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd); + const int token = (int) qrm.x; + if (token < ncols_dst_total) { + const int slot = (int) qrm.y; + dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, @@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids( const int64_t stride_col_id, const int64_t stride_row_id, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { - if (ids) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream, + const mmf_ids_data * ids_data) { + const bool has_ids_data = ids_data && ids_data->ids_src_compact; + + // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16) + // we prefer the normal mul_mat_f path with has_ids=true. + if (has_ids_data && ncols_dst > 16) { + const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); + if (max_tiles == 0) { + return; + } + dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); + + const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); + const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); + + mul_mat_f_ids<<>> + (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, + ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd); + } else if (ids) { const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; + mul_mat_f<<>> - (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { @@ -258,7 +532,7 @@ void mul_mat_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; @@ -290,7 +564,7 @@ void mul_mat_f_cuda( const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; - const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_dims(warp_size, nwarps_best, 1); @@ -300,49 +574,57 @@ void mul_mat_f_cuda( mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 2: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 3: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 4: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 5: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 6: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 7: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 8: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; @@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block( case 1: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ - cudaStream_t stream); + cudaStream_t stream, const mmf_ids_data * ids_data); #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ diff --git a/ggml/src/ggml-cuda/mmid.cu b/ggml/src/ggml-cuda/mmid.cu new file mode 100644 index 0000000000000..3c61e4595a7b1 --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cu @@ -0,0 +1,164 @@ +#include "common.cuh" +#include "mmid.cuh" + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mm_ids_helper_store { + uint32_t data; + + __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mm_ids_helper[]; + mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mm_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mm_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mm_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 16: + launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh new file mode 100644 index 0000000000000..ac090aea9ea1a --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cuh @@ -0,0 +1,5 @@ +#pragma once + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 12bdc629bd6b2..a2c8760abea93 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,141 +1,6 @@ #include "mmq.cuh" #include "quantize.cuh" - -#include - -// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. -struct mmq_ids_helper_store { - uint32_t data; - - __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { - data = (it & 0x003FFFFF) | (iex_used << 22); - } - - __device__ uint32_t it() const { - return data & 0x003FFFFF; - } - - __device__ uint32_t iex_used() const { - return data >> 22; - } -}; -static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); - -// Helper function for mul_mat_id, converts ids to a more convenient format. -// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. -// ids_dst describes the same mapping but for the dst tensor. -// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. -template -__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) -static __global__ void mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; - const int expert = blockIdx.x; - - extern __shared__ char data_mmq_ids_helper[]; - mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; - - int nex_prev = 0; // Number of columns for experts with a lower index. - int it_compact = 0; // Running index for the compact slice of this expert. - - if constexpr (n_expert_used_template == 0) { - // Generic implementation: - for (int it = 0; it < n_tokens; ++it) { - int iex_used = -1; // The index at which the expert is used, if any. - for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { - const int expert_used = ids[it*si1 + iex]; - nex_prev += expert_used < expert; - if (expert_used == expert) { - iex_used = iex; - } - } - - if (iex_used != -1) { - store[it_compact] = mmq_ids_helper_store(it, iex_used); - } - - if (warp_reduce_any(iex_used != -1)) { - it_compact++; - } - } - } else { - // Implementation optimized for specific numbers of experts used: - static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); - const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. - for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { - const int it = it0 + threadIdx.x / neu_padded; - - const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. - const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? - ids[it*si1 + iex] : INT_MAX; - const int iex_used = expert_used == expert ? iex : -1; - nex_prev += expert_used < expert; - - // Whether the threads at this token position have used the expert: - const int it_compact_add_self = warp_reduce_any(iex_used != -1); - - // Do a scan over threads at lower token positions in warp to get the correct index for writing data: - int it_compact_add_lower = 0; -#pragma unroll - for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { - const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); - if (threadIdx.x >= static_cast(offset)) { - it_compact_add_lower += tmp; - } - } - - if (iex_used != -1) { - store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); - } - - // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: - it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); - } - } - nex_prev = warp_reduce_sum(nex_prev); - - for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { - const mmq_ids_helper_store store_it = store[itc]; - const int it = store_it.it(); - const int iex_used = store_it.iex_used(); - ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; - ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; - } - - if (threadIdx.x != 0) { - return; - } - - expert_bounds[expert] = nex_prev; - - if (expert < static_cast(gridDim.x) - 1) { - return; - } - - expert_bounds[gridDim.x] = nex_prev + it_compact; -} - -template -static void launch_mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { - GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); - GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); - - const int id = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[id].warp_size; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); - - const dim3 num_blocks(n_experts, 1, 1); - const dim3 block_size(warp_size, 1, 1); - const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); - GGML_ASSERT(nbytes_shared <= smpbo); - mmq_ids_helper<<>> - (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); -} +#include "mmid.cuh" static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { @@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q( const int si1 = ids->nb[1] / ggml_element_size(ids); const int sis1 = nb12 / nb11; - switch (n_expert_used) { - case 2: - launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 4: - launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 6: - launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 8: - launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 16: - launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 32: - launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - default: - launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - } + ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 5b21ef05b3c35..57ab839393aa0 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -7,14 +7,14 @@ template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const int row = blockIdx.x; const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; const int tid = threadIdx.x; @@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x*tmpy.x; - sumf[j] += tmpx.y*tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else if constexpr (std::is_same_v) { @@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x * tmpy.x; - sumf[j] += tmpx.y * tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else { @@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f( #endif // FP16_AVAILABLE } } else if constexpr (std::is_same_v) { +//TODO: add support for ggml_cuda_mad for hip_bfloat162 +#if defined(GGML_USE_HIP) const int * x2 = (const int *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[1]) * tmpy.y; + const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); + const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); + ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); } } +#else + const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const nv_bfloat162 tmpx = x2[col2]; +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + } + } +#endif } else { static_assert(std::is_same_v, "unsupported type"); } @@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda( GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda( case 32: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3fe8f4e91002..553cf8f5f39ac 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -7,6 +7,8 @@ #include +#include + #ifndef TARGET_OS_VISION #define TARGET_OS_VISION 0 #endif @@ -22,6 +24,9 @@ // overload of MTLGPUFamilyMetal3 (not available in some environments) static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; +// virtual address for GPU memory allocations +static atomic_uintptr_t g_addr_device = 0x000000400ULL; + #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -827,7 +832,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; struct ggml_metal_buffer { - void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 + void * all_data; size_t all_size; // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host @@ -965,14 +970,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, if (shared) { res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; - res->owned = true; } else { - // dummy, non-NULL value - we'll populate this after creating the Metal buffer below - res->all_data = (void *) 0x000000400ULL; + // use virtual address from g_addr_device counter + res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; + res->owned = true; + res->device = ggml_metal_device_get_obj(dev); res->queue = ggml_metal_device_get_queue(dev); @@ -983,15 +989,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->buffers[0].metal = nil; if (size_aligned > 0) { - if (props_dev->use_shared_buffers &&shared) { + if (props_dev->use_shared_buffers && shared) { res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; } else { res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; - - res->all_data = (void *) (res->buffers[0].metal.gpuAddress); } } @@ -1139,7 +1143,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { if (buf->is_shared) { - memset((char *)tensor->data + offset, value, size); + memset((char *) tensor->data + offset, value, size); return; } @@ -1168,7 +1172,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy((char *)tensor->data + offset, data, size); + memcpy((char *) tensor->data + offset, data, size); return; } @@ -1223,7 +1227,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy(data, (const char *)tensor->data + offset, size); + memcpy(data, (const char *) tensor->data + offset, size); return; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a448c14f66b63..fa2d82cefb40e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -251,6 +251,7 @@ typedef struct { int32_t sect_1; int32_t sect_2; int32_t sect_3; + bool src2; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a61ea8fb5a7b3..784b7b77851e6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2969,6 +2969,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* sect_1 =*/ sect_1, /* sect_2 =*/ sect_2, /* sect_3 =*/ sect_3, + /* src2 =*/ op->src[2] != nullptr, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 1029cf8f9a3ab..6d39ddcc634ef 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3748,7 +3748,7 @@ kernel void kernel_rope_norm( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3801,7 +3801,7 @@ kernel void kernel_rope_neox( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3872,7 +3872,7 @@ kernel void kernel_rope_multi( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3939,7 +3939,7 @@ kernel void kernel_rope_vision( const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); // end of mrope - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d2759069b3e29..0693d38d80af6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx // if rms_norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && - !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 83a83887b5180..de01336cd3fd2 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,9 +1,18 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) cmake_policy(SET CMP0116 NEW) +if (POLICY CMP0147) + # Parallel build custom build steps + cmake_policy(SET CMP0147 NEW) +endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Parallel build object files + add_definitions(/MP) +endif() + function(detect_host_compiler) if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cd89c711650d..1674dc66ab912 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) @@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) @@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); - const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); - const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + // For F32, the shader treats it as a block of size 4 (for vec4 loads) + if (k->type == GGML_TYPE_F32) { + k_stride /= 4; + } + if (v->type == GGML_TYPE_F32) { + v_stride /= 4; + } uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); bool aligned = (KV % alignment) == 0 && @@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } switch (op->src[1]->type) { case GGML_TYPE_F16: + case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: // supported in scalar and coopmat2 paths diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 6a5bb4574d713..67baedf7c6147 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,6 +1,18 @@ #include "types.glsl" +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { + vec4 block; +}; + +float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const vec4 v = bl.block; + const uint idx = coordInBlock[1]; + const f16vec4 vf16 = f16vec4(v); + return vf16[idx]; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_F32) +#define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9b1f153bf7f19..eb93903c4681e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; -#if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 +#if defined(DATA_A_F32) +layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; +#elif defined(A_TYPE_PACKED16) layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(DATA_A_F32) +#undef BLOCK_SIZE +#define BLOCK_SIZE 4 +#define BLOCK_BYTE_SIZE 16 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + // iqs is currently always zero in the flash attention shaders + if (binding_idx == BINDING_IDX_K) { + return k_packed.k_data_packed[a_offset + ib]; + } else { + return v_packed.v_data_packed[a_offset + ib]; + } +} +#endif + #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 85400ac5fc343..a20788c4b51e3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -313,12 +313,12 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE sums[WMITER * TM * WNITER * TN]; + ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b[TN]; + FLOAT_TYPE_VEC2 cache_b; - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = ACC_TYPE(0.0f); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); } #endif @@ -360,20 +360,22 @@ void main() { cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; } } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); } } } } + } #endif @@ -388,8 +390,9 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX); + sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif #endif @@ -463,14 +466,21 @@ void main() { const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID - [[unroll]] for (uint cr = 0; cr < TM; cr++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; #ifdef MUL_MAT_ID - if (dr_warp + cr < p.M) { - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #else - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #endif // MUL_MAT_ID } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f0cc24ff31e1e..184f3f3a7db51 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -611,9 +611,6 @@ void process_shaders() { } for (const auto& tname : type_names) { - if (tname == "f32") { - continue; - } if (tname == "bf16") continue; #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -630,7 +627,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); @@ -639,7 +636,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3f404f5f1a032..14472dcf124fb 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6911,7 +6911,7 @@ static std::vector> make_test_cases_perf() { } // qwen3-30b-a3b - for (int bs : {1, 4, 8, 32, 64, 128, 512}) { + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); @@ -6919,6 +6919,15 @@ static std::vector> make_test_cases_perf() { } } + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { + for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); + } + } + } + + // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { From 63f9f6cf435aa5d0a262f126930aff8ad2cb51db Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 16:47:12 -0500 Subject: [PATCH 3/8] use F32 accumulators for GLM4V_MOE --- src/llama-graph.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f29a1e98c9103..ffc2187a1b107 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -817,7 +817,7 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_GLM4V_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } @@ -1583,7 +1583,7 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_GLM4V_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } From 65603e1c3878a93a1851475fc0957d464c685c93 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 16:47:31 -0500 Subject: [PATCH 4/8] update notes --- convert_hf_to_gguf.py | 62 ++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 36278866da005..ea1a98801cc01 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9221,30 +9221,44 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Glm4vMoeForConditionalGeneration") class GLM4V_MoE(MmprojModel): - # - # the HF model's type is `glm4v_moe`. internally, it consists of two models: - # - `glm4v_moe_text` - # + main text model - # + tensor names start with "model.language_model." - # + "2D-RoPE" (aKa Roformer) w/ embeddings dynamically adapted via bicubic interpolation - # - `glm4v_moe` - # + vision adapter (ViT) - # + tensor names start with "model.visual." - # + "3D-RoPE" (without the interpolation mentioned above) - # - # other notable quirks include: - # - has MTP layer (need to keep these tensors - same as GLM-4.5-Air) - # - RoPE theta value (θ): use 10k rather than 100k for GLM-4.5-Air - # - the model's vision supports video input, but this is not implemented here - # - # for more info, refer to: - # - reference impl : https://github.com/huggingface/transformers/tree/main/src/transformers/models/glm4v_moe - # - HF model card : https://huggingface.co/zai-org/GLM-4.5V - # - arXiv paper (model) : https://arxiv.org/abs/2507.01006 - # - arXiv paper (orig. ViT) : https://arxiv.org/abs/2411.14402 - # - # TODO: the model's tokenizer has video-related special tokens - deal with these (??) - # + """The HF architecture is called **`Glm4vMoeForConditionalGeneration`** (`"model_type": "glm4v_moe"`). Internally, this consists of an LLM (text model) and a ViT (vision adapter / multimodal projector): + + ### LLM (text model `glm4v_moe_text`) + - Based on GLM-4.5-Air + - Tensor names start with `model.language_model.` + - Uses a "multimodal 3D RoPE" - in `apply_multimodal_rotary_pos_emb`, it applies rotary embeddings across temporal, height, and width dimensions for visual tokens + + ### ViT (vision adapter `glm4v_moe`) + - Adapted from [apple/aimv2-huge-patch14-336](https://huggingface.co/apple/aimv2-huge-patch14-336): + + Architecture **`Aimv2VisionModel`** + + ~681M params + + 24 layers + + hidden_size (n_embd): 1536 + + intermediate_size (n_ff): 4096 + + image_size: 336 + + patch_size: 14 + + num_channels: 3 + + depth: 24 + - Tensor names start with `model.visual.` + - Its 2D positional embeddings are dynamically adapted via bicubic interpolation within the `Glm4vMoeVisionEmbeddings` module to handle varied image resolutions + - It also applies its own rotary position embeddings within the self-attention blocks (via `apply_rotary_pos_emb_vision`) + + ## Other notes: + - Native context length is `65_536` (as opposed to `131_072` for GLM-4.5-Air) + - RoPE theta (θ): `10_000.0` (as opposed to `100_000.0` for GLM-4.5-Air) + - The model supports video input, but I currently do not plan to support video input in this PR + - Tokenizer has video-related special tokens - need to handle these during conversion + + ### References: + - The HF reference implementations: + + [modeling_glm4v_moe.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py) + + [modular_glm4v_moe.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modular_glm4v_moe.py) + - The HF [model card](https://huggingface.co/zai-org/GLM-4.5V) + - The HF [config.json](https://huggingface.co/zai-org/GLM-4.5V/blob/main/config.json) + + ### See also: + - [arXiv:2507.01006](https://arxiv.org/abs/2507.01006) + - [arXiv:2411.14402](https://arxiv.org/abs/2411.14402)""" pass From 7eaefc3be069cc444ca9f738ca13909cc457eb31 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 16:48:21 -0500 Subject: [PATCH 5/8] add arch --- gguf-py/gguf/constants.py | 34 ++++++++++++++++++++++++++++++---- src/llama-arch.cpp | 28 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f5e5fba8008bd..0afc58331b565 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum): CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() + GLM4V_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -656,10 +657,10 @@ class MODEL_TENSOR(IntEnum): A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() # nextn/mtp - NEXTN_EH_PROJ = auto() - NEXTN_EMBED_TOKENS = auto() - NEXTN_ENORM = auto() - NEXTN_HNORM = auto() + NEXTN_EH_PROJ = auto() + NEXTN_EMBED_TOKENS = auto() + NEXTN_ENORM = auto() + NEXTN_HNORM = auto() NEXTN_SHARED_HEAD_HEAD = auto() NEXTN_SHARED_HEAD_NORM = auto() @@ -729,6 +730,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", + MODEL_ARCH.GLM4V_MOE: "glm4v_moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -2273,6 +2275,30 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.GLM4V_MOE: [ # same as GLM4_MOE without MTP tensors + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 869e4dccf0dc9..1f6ab7618bbb3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM4V_MOE, "glm4v_moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1502,6 +1503,33 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, }, }, + { + LLM_ARCH_GLMV4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_BITNET, { From 913280685e15f4b9374cd0cece55fdb0c4dfc821 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 17:10:55 -0500 Subject: [PATCH 6/8] llama-model : add placeholders --- src/llama-model.cpp | 20 ++++++++++++++++++++ src/llama-model.h | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0cdad9babd9b2..3ba4033ac04b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1610,6 +1610,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4V_MOE: + { + // TODO + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4891,6 +4895,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM4V_MOE: + { + // TODO + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -14682,6 +14691,12 @@ struct llm_build_glm4_moe : public llm_graph_context { } }; +struct llm_build_glm4v_moe : public llm_graph_context { + llm_build_glm4v_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + // TODO + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19749,6 +19764,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4V_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -20117,6 +20136,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_GLM4V_MOE: return LLAMA_ROPE_TYPE_MROPE; // all model arches should be listed explicitly here diff --git a/src/llama-model.h b/src/llama-model.h index 7f48662f2807a..2c9b05fbc790f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -110,7 +110,7 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, - LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_106B_A12B, // GLM-4.5-Air (and GLM-4.5V text model) LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_355B_A32B, // GLM-4.5 From f88780f4430e84875ca572da18e427663c94b3c5 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 17:13:29 -0500 Subject: [PATCH 7/8] fix arch name for tensor names --- src/llama-arch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1f6ab7618bbb3..69ea3f3e3ecf0 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1504,7 +1504,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_GLMV4_MOE, + LLM_ARCH_GLM4V_MOE, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, From 128d85079732e2c0af06de4a0e8a922074892809 Mon Sep 17 00:00:00 2001 From: ddh0 Date: Tue, 14 Oct 2025 21:28:27 -0500 Subject: [PATCH 8/8] oops --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ea1a98801cc01..f4f660bd3773b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9246,7 +9246,7 @@ class GLM4V_MoE(MmprojModel): ## Other notes: - Native context length is `65_536` (as opposed to `131_072` for GLM-4.5-Air) - RoPE theta (θ): `10_000.0` (as opposed to `100_000.0` for GLM-4.5-Air) - - The model supports video input, but I currently do not plan to support video input in this PR + - The model supports video input, but this is not yet implemented (only images) - Tokenizer has video-related special tokens - need to handle these during conversion ### References: