Commit 5d9ae95

musa: extract ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent cd42f3e commit 5d9ae95

File tree

5 files changed: +160 additions, -31 deletions

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 63 additions & 28 deletions
@@ -40,6 +40,9 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#ifdef GGML_USE_MUSA
+#include "ggml-musa/mublas.cuh"
+#endif // GGML_USE_MUSA
 #include "ggml.h"

 #include <algorithm>
@@ -1745,6 +1748,52 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }

+#ifndef GGML_USE_MUSA
+static void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+) {
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // use cublasGemmBatchedEx
+    const int64_t ne23 = ne12*ne13;
+
+    ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+    ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+            src0_f16, src1_f16, dst_t,
+            ptrs_src.get(), ptrs_dst.get(),
+            ne12, ne13,
+            ne23,
+            nb02, nb03,
+            src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+            src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+            nbd2, nbd3,
+            r2, r3);
+    CUDA_CHECK(cudaGetLastError());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
+                   (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+            ne23,
+            cu_compute_type,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+}
+#endif // GGML_USE_MUSA
+
 static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -1872,34 +1921,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
-        // use cublasGemmBatchedEx
-        const int64_t ne23 = ne12*ne13;
-
-        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
-
-        dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
-        CUDA_CHECK(cudaGetLastError());
-
-        CUBLAS_CHECK(
-        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
-                ne23,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+            ctx,
+            src0, src1, dst,
+            src0_f16, src1_f16, dst_t,
+            nbd2, nbd3,
+            r2, r3,
+            s11, s12, s13,
+            alpha, beta,
+            cu_data_type, cu_compute_type,
+            main_stream);
     }
 #endif

@@ -3018,6 +3049,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
                 return false;
             }
+            if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                return false;
+            }
         }
 #endif // GGML_USE_MUSA
         switch (a->type) {
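
Note on the structure of this change: the helper keeps a single signature, but its definition is selected at compile time. CUDA builds compile the static definition above, guarded by #ifndef GGML_USE_MUSA, while MUSA builds compile the non-static definition from ggml-musa/mublas.cu instead, so the call site in ggml_cuda_mul_mat_batched_cublas stays identical. A minimal, compilable sketch of the same pattern, with a hypothetical backend_gemm standing in for the real helper:

// dispatch_sketch.cpp -- illustrative only; backend_gemm is a hypothetical
// stand-in for ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex.
// Build with -DGGML_USE_MUSA to select the alternate definition, mirroring
// how mublas.cu replaces the static helper compiled into ggml-cuda.cu.
#include <cstdio>

void backend_gemm(int n);            // one declaration, shared by both builds

#ifndef GGML_USE_MUSA
// default definition: plays the role of the static helper in ggml-cuda.cu
void backend_gemm(int n) { std::printf("cuBLAS path, n=%d\n", n); }
#else
// MUSA definition: plays the role of ggml-musa/mublas.cu
void backend_gemm(int n) { std::printf("muBLAS path, n=%d\n", n); }
#endif

int main() {
    backend_gemm(32);                // call site is unchanged either way
    return 0;
}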

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ if (MUSAToolkit_FOUND)

     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
-    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
+    file(GLOB   HRDS "../ggml-musa/*.cuh")
+    list(APPEND GGML_HEADERS_MUSA ${HRDS})

     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")

ggml/src/ggml-musa/mublas.cu

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+#include "mublas.cuh"
+
+static __global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void ** ptrs_src, void ** ptrs_dst,
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t nb02, size_t nb03,
+        size_t nb12, size_t nb13,
+        size_t nbd2, size_t nbd3,
+        int64_t r2, int64_t r3) {
+    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *) dst         + i12*nbd2 + i13*nbd3;
+}
+
+void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+) {
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // use cublasGemmBatchedEx
+    const int64_t ne23 = ne12*ne13;
+
+    // Allocate memory for pointer arrays using cudaMalloc to avoid segmentation faults in muBLAS.
+    const void ** ptrs_src;
+    void ** ptrs_dst;
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+    CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+            src0_f16, src1_f16, dst_t,
+            ptrs_src, ptrs_dst,
+            ne12, ne13,
+            ne23,
+            nb02, nb03,
+            src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+            src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+            nbd2, nbd3,
+            r2, r3);
+    CUDA_CHECK(cudaGetLastError());
+
+    // This operation is essential for musa; without it, generated tokens will
+    // be garbled and may eventually cause MUBLAS_STATUS_INTERNAL_ERROR.
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/nb00,
+                   (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, s11,
+            beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne0,
+            ne23,
+            cu_compute_type,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    CUDA_CHECK(cudaFree(ptrs_src));
+    CUDA_CHECK(cudaFree(ptrs_dst));
+}
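
The MUSA variant above differs from the CUDA path in two deliberate ways, both called out in its comments: the pointer arrays come from raw cudaMalloc/cudaFree rather than the stream pool allocator (pool memory reportedly segfaults inside muBLAS), and a full cudaDeviceSynchronize is issued before the batched GEMM. One consequence of the raw allocation is that the trailing cudaFree calls are skipped if an earlier check aborts; a small RAII guard, sketched below as a hypothetical tidy-up (not part of the commit), would release the arrays on every exit path:

// device_ptr_array: hypothetical RAII guard around the cudaMalloc/cudaFree
// pair used above. Illustrative sketch only -- the commit frees manually.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

struct device_ptr_array {
    void ** ptr = nullptr;

    explicit device_ptr_array(size_t count) {
        // same raw device allocation the MUSA path performs
        if (cudaMalloc((void **) &ptr, sizeof(void *) * count) != cudaSuccess) {
            std::fprintf(stderr, "cudaMalloc failed\n");
            std::abort();
        }
    }
    ~device_ptr_array() { cudaFree(ptr); }   // runs on every exit path

    device_ptr_array(const device_ptr_array &)             = delete;
    device_ptr_array & operator=(const device_ptr_array &) = delete;
};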

ggml/src/ggml-musa/mublas.cuh

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#include "ggml-cuda/common.cuh"
+
+void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const half * src0_f16, const half * src1_f16, char * dst_t,
+    const size_t nbd2, const size_t nbd3,
+    const int64_t r2, const int64_t r3,
+    const int64_t s11, const int64_t s12, const int64_t s13,
+    const void * alpha, const void * beta,
+    const cudaDataType_t cu_data_type,
+    const cublasComputeType_t cu_compute_type,
+    cudaStream_t main_stream
+);

ggml/src/ggml-musa/mudnn.cuh

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 #pragma once

-#include "../include/ggml.h"
-#include "../ggml-cuda/common.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"

 // Asynchronously copies data from src tensor to dst tensor using the provided context.
 // Returns a musaError_t indicating success or failure.
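
The prototype this comment documents lies below the shown diff context, so it does not appear here. As a rough usage sketch only, assuming a (ctx, dst, src) call shape returning musaError_t, with the function name invented for illustration:

// Hypothetical usage of the async tensor copy documented above. The real
// prototype is outside the shown diff context, so the function name
// (mudnn_async_copy) and the (ctx, dst, src) parameter order are assumptions.
static bool copy_tensor_async(ggml_backend_cuda_context & ctx,
                              ggml_tensor * dst, const ggml_tensor * src) {
    const musaError_t err = mudnn_async_copy(ctx, dst, src);
    if (err != musaSuccess) {
        return false;  // surface the failure to the caller
    }
    return true;       // copy was enqueued on the context's stream
}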
