@@ -1750,7 +1750,7 @@ static void ggml_cuda_op_mul_mat(
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const void * src0_as_f16, const void * src1_as_f16, char * dst,
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1773,139 +1773,91 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-// Type traits for mapping ggml types to CUDA/cuBLAS types
-template<ggml_type T>
-struct batched_mul_mat_traits;
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_F32> {
-    using cuda_type = float;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
-    static inline const cudaDataType_t data_type = CUDA_R_32F;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
-    static inline const float alpha = 1.0f;
-    static inline const float beta = 0.0f;
-    static inline const void * get_alpha() { static const float val = alpha; return &val; }
-    static inline const void * get_beta() { static const float val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
-};
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_BF16> {
-    using cuda_type = nv_bfloat16;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
-    static inline const cudaDataType_t data_type = CUDA_R_16BF;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
-    static inline const float alpha = 1.0f;
-    static inline const float beta = 0.0f;
-    static inline const void * get_alpha() { static const float val = alpha; return &val; }
-    static inline const void * get_beta() { static const float val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
-};
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_F16> {
-    using cuda_type = half;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
-    static inline const cudaDataType_t data_type = CUDA_R_16F;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
-    static inline const half alpha = 1.0;
-    static inline const half beta = 0.0;
-    static inline const void * get_alpha() { static const half val = alpha; return &val; }
-    static inline const void * get_beta() { static const half val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
-};
-
-template<ggml_type src0_type>
-static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    using traits = batched_mul_mat_traits<src0_type>;
-    using cuda_t = typename traits::cuda_type;
-
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
+
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == src0_type);
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
+
     cudaStream_t main_stream = ctx.stream();
+
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
+    const half * src0_f16 = (const half *) src0->data;
     float * dst_ddf = (float *) dst->data;
+
+    const half * src1_f16 = (const half *) src1->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
-    const cuda_t * src0_ptr = nullptr;
-    const cuda_t * src1_ptr = nullptr;
-
-    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
-
-    // Handle src0
-    src0_ptr = (const cuda_t *) src0->data;
-
-    // Handle src1 - convert if necessary
-    if (src1->type == src0_type) {
-        src1_ptr = (const cuda_t *) src1->data;
-    } else {
-        // Convert src1 to target type using traits conversion functions
+    // convert src1 to fp16
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
-        src1_alloc.alloc(ne_src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        const auto convert_func = traits::get_nc_converter(src1->type);
-        GGML_ASSERT(convert_func != nullptr);
-        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-        src1_ptr = src1_alloc.get();
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
     }
 
-    // Setup destination buffer
-    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
+    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
     char * dst_t;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    cublasComputeType_t cu_compute_type = traits::compute_type;
-    cudaDataType_t cu_data_type = traits::data_type;
-    cudaDataType_t cu_data_type_a = traits::data_type;
-    cudaDataType_t cu_data_type_b = traits::data_type;
-    const void * alpha = traits::get_alpha();
-    const void * beta = traits::get_beta();
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        if constexpr (src0_type == GGML_TYPE_F32) {
-            dst_t = (char *) dst_ddf; // Direct F32 output
-        } else {
-            dst_t = (char *) dst_temp.alloc(ne_dst);
-            nbd2 /= sizeof(float) / sizeof(cuda_t);
-            nbd3 /= sizeof(float) / sizeof(cuda_t);
-        }
+        dst_t = (char *) dst_f16.alloc(ne_dst);
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
     } else {
         dst_t = (char *) dst_ddf;
+
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
+        cu_data_type    = CUDA_R_32F;
+
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta  = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta  = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -1915,15 +1867,35 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
+                           src1_f16 + i13*s13 + i12*s12,                  CUDA_R_16F,   s11,
+                    beta,  (      char *) dst_t + i13*nbd3 + i12*nbd2,    cu_data_type, ne0,
+                    cu_compute_type,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
+                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+                beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1934,55 +1906,34 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
-        size_t src1_stride_size = sizeof(cuda_t);
-
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_ptr, src1_ptr, dst_t,
+                src0_f16, src1_f16, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
-                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
                 nbd2, nbd3,
                 r2, r3);
-
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif
 
-    // Convert output back to F32 if needed
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
-        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
-    }
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_BF16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_F16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
 }
 
@@ -2034,12 +1985,6 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    // TODO update for generic tensor parallelism
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
-    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
-    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
-
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2048,8 +1993,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
-        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {