40 | 40 | #include "ggml-cuda/upscale.cuh" |
41 | 41 | #include "ggml-cuda/wkv.cuh" |
42 | 42 | #include "ggml-cuda/gla.cuh" |
| 43 | +#ifdef GGML_USE_MUSA |
| 44 | +#include "ggml-musa/mublas.cuh" |
| 45 | +#endif // GGML_USE_MUSA |
43 | 46 | #include "ggml.h" |
44 | 47 |
45 | 48 | #include <algorithm> |
@@ -1745,6 +1748,52 @@ static __global__ void k_compute_batched_ptrs( |
1745 | 1748 | ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; |
1746 | 1749 | } |
1747 | 1750 |
| 1751 | +#ifndef GGML_USE_MUSA |
| 1752 | +static void ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex( |
| 1753 | + ggml_backend_cuda_context & ctx, |
| 1754 | + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, |
| 1755 | + const half * src0_f16, const half * src1_f16, char * dst_t, |
| 1756 | + const size_t nbd2, const size_t nbd3, |
| 1757 | + const int64_t r2, const int64_t r3, |
| 1758 | + const int64_t s11, const int64_t s12, const int64_t s13, |
| 1759 | + const void * alpha, const void * beta, |
| 1760 | + const cudaDataType_t cu_data_type, |
| 1761 | + const cublasComputeType_t cu_compute_type, |
| 1762 | + cudaStream_t main_stream |
| 1763 | +) { |
| 1764 | + GGML_TENSOR_BINARY_OP_LOCALS |
| 1765 | + |
| 1766 | + // use cublasGemmBatchedEx |
| 1767 | + const int64_t ne23 = ne12*ne13; |
| 1768 | + |
| 1769 | + ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); |
| 1770 | + ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); |
| 1771 | + |
| 1772 | + dim3 block_dims(ne13, ne12); |
| 1773 | + k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( |
| 1774 | + src0_f16, src1_f16, dst_t, |
| 1775 | + ptrs_src.get(), ptrs_dst.get(), |
| 1776 | + ne12, ne13, |
| 1777 | + ne23, |
| 1778 | + nb02, nb03, |
| 1779 | + src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half), |
| 1780 | + src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half), |
| 1781 | + nbd2, nbd3, |
| 1782 | + r2, r3); |
| 1783 | + CUDA_CHECK(cudaGetLastError()); |
| 1784 | + |
| 1785 | + CUBLAS_CHECK( |
| 1786 | + cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, |
| 1787 | + ne01, ne11, ne10, |
| 1788 | + alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, |
| 1789 | + (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11, |
| 1790 | + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, |
| 1791 | + ne23, |
| 1792 | + cu_compute_type, |
| 1793 | + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |
| 1794 | +} |
| 1795 | +#endif // GGML_USE_MUSA |
| 1796 | + |
1748 | 1797 | static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
1749 | 1798 | GGML_ASSERT(!ggml_is_transposed(src0)); |
1750 | 1799 | GGML_ASSERT(!ggml_is_transposed(src1)); |
@@ -1872,34 +1921,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co |
1872 | 1921 | cu_compute_type, |
1873 | 1922 | CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |
1874 | 1923 | } else { |
1875 | | - // use cublasGemmBatchedEx |
1876 | | - const int64_t ne23 = ne12*ne13; |
1877 | | - |
1878 | | - ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); |
1879 | | - ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); |
1880 | | - |
1881 | | - dim3 block_dims(ne13, ne12); |
1882 | | - k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( |
1883 | | - src0_f16, src1_f16, dst_t, |
1884 | | - ptrs_src.get(), ptrs_dst.get(), |
1885 | | - ne12, ne13, |
1886 | | - ne23, |
1887 | | - nb02, nb03, |
1888 | | - src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half), |
1889 | | - src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half), |
1890 | | - nbd2, nbd3, |
1891 | | - r2, r3); |
1892 | | - CUDA_CHECK(cudaGetLastError()); |
1893 | | - |
1894 | | - CUBLAS_CHECK( |
1895 | | - cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, |
1896 | | - ne01, ne11, ne10, |
1897 | | - alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, |
1898 | | - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11, |
1899 | | - beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, |
1900 | | - ne23, |
1901 | | - cu_compute_type, |
1902 | | - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |
| 1924 | + ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex( |
| 1925 | + ctx, |
| 1926 | + src0, src1, dst, |
| 1927 | + src0_f16, src1_f16, dst_t, |
| 1928 | + nbd2, nbd3, |
| 1929 | + r2, r3, |
| 1930 | + s11, s12, s13, |
| 1931 | + alpha, beta, |
| 1932 | + cu_data_type, cu_compute_type, |
| 1933 | + main_stream); |
1903 | 1934 | } |
1904 | 1935 | #endif |
1905 | 1936 |
@@ -3018,6 +3049,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g |
3018 | 3049 | a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) { |
3019 | 3050 | return false; |
3020 | 3051 | } |
| 3052 | + if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID && |
| 3053 | + a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) { |
| 3054 | + return false; |
| 3055 | + } |
3021 | 3056 | } |
3022 | 3057 | #endif // GGML_USE_MUSA |
3023 | 3058 | switch (a->type) { |
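
For context on the pattern that the new `ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex` helper factors out: `cublasGemmBatchedEx` consumes device-resident arrays of per-matrix pointers rather than a single strided allocation, which is why the patch first launches `k_compute_batched_ptrs` to fill those arrays before the batched GEMM call. The standalone sketch below is not part of the patch; the matrix sizes, the host-side pointer setup, and the OP_N/OP_N layout are illustrative assumptions, whereas the real code builds the pointer arrays on the GPU and transposes `src0`.

```cpp
// Minimal sketch of the cublasGemmBatchedEx pointer-array pattern
// (F16 inputs, F32 accumulation). Error checking omitted for brevity.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

int main() {
    const int m = 64, n = 64, k = 64, batch = 8;

    cublasHandle_t handle;
    cublasCreate(&handle);

    // One contiguous allocation per operand, sliced into `batch` matrices.
    half  *A, *B;
    float *C;
    cudaMalloc(&A, sizeof(half)  * m * k * batch);
    cudaMalloc(&B, sizeof(half)  * k * n * batch);
    cudaMalloc(&C, sizeof(float) * m * n * batch);

    // cublasGemmBatchedEx expects *device* arrays of per-matrix pointers.
    std::vector<const void *> hA(batch), hB(batch);
    std::vector<void *>       hC(batch);
    for (int i = 0; i < batch; ++i) {
        hA[i] = A + (size_t) i * m * k;
        hB[i] = B + (size_t) i * k * n;
        hC[i] = C + (size_t) i * m * n;
    }
    const void **dA; const void **dB; void **dC;
    cudaMalloc(&dA, sizeof(void *) * batch);
    cudaMalloc(&dB, sizeof(void *) * batch);
    cudaMalloc(&dC, sizeof(void *) * batch);
    cudaMemcpy(dA, hA.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);

    // With CUBLAS_COMPUTE_32F, alpha and beta are passed as float.
    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k,
        &alpha, dA, CUDA_R_16F, m,
                dB, CUDA_R_16F, k,
        &beta,  dC, CUDA_R_32F, m,
        batch,
        CUBLAS_COMPUTE_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    cudaFree(A);  cudaFree(B);  cudaFree(C);
    cublasDestroy(handle);
    return 0;
}
```

In the refactored code the same call is issued with `CUBLAS_OP_T` on `src0`, leading dimensions derived from the tensor strides (`nb01/nb00`, `s11`, `ne0`), and the data/compute types passed in by the caller, so the MUSA build can exclude this path with `#ifndef GGML_USE_MUSA` while keeping the CUDA path unchanged.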