Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu")
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}

if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
return;
}
}

cudaStream_t stream = ctx.stream();
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
Expand Down
382 changes: 34 additions & 348 deletions ggml/src/ggml-cuda/mmf.cu

Large diffs are not rendered by default.

466 changes: 465 additions & 1 deletion ggml/src/ggml-cuda/mmf.cuh

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions ggml/src/ggml-cuda/template-instances/generate_cu_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
DECL_MMQ_CASE({type});
"""

SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmf.cuh"

DECL_MMF_CASE({type});
"""


def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower()
Expand Down Expand Up @@ -76,3 +83,7 @@ def get_head_sizes(type_k, type_v):
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type))

for type in range(1, 17):
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
f.write(SOURCE_MMF.format(type=type))
2 changes: 1 addition & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6261,7 +6261,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int n_mats : {4, 8}) {
for (int n_used : {1, 2, 4}) {
for (bool b : {false, true}) {
for (int n : {1, 32, 129}) {
for (int n : {1, 4, 5, 32, 129}) {
int m = 512;
int k = 256;
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
Expand Down
Loading