@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void * src0_as_f16, const void * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1772,91 +1772,139 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+// Type traits for mapping ggml types to CUDA/cuBLAS types
+template<ggml_type T>
+struct batched_mul_mat_traits;
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F32> {
+    using cuda_type = float;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_32F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void * get_alpha() { static const float val = alpha; return &val; }
+    static inline const void * get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+    using cuda_type = nv_bfloat16;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+    static inline const cudaDataType_t data_type = CUDA_R_16BF;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+    static inline const float alpha = 1.0f;
+    static inline const float beta = 0.0f;
+    static inline const void * get_alpha() { static const float val = alpha; return &val; }
+    static inline const void * get_beta() { static const float val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F16> {
+    using cuda_type = half;
+    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+    static inline const cudaDataType_t data_type = CUDA_R_16F;
+    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+    static inline const half alpha = 1.0;
+    static inline const half beta = 0.0;
+    static inline const void * get_alpha() { static const half val = alpha; return &val; }
+    static inline const void * get_beta() { static const half val = beta; return &val; }
+    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+};
+
+template<ggml_type src0_type>
+static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    using traits = batched_mul_mat_traits<src0_type>;
+    using cuda_t = typename traits::cuda_type;
+
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == src0_type);
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
-
     cudaStream_t main_stream = ctx.stream();
-
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
     float * dst_ddf = (float *) dst->data;
-
-    const half * src1_f16 = (const half *) src1->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+    const cuda_t * src0_ptr = nullptr;
+    const cuda_t * src1_ptr = nullptr;
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+
+    // Handle src0
+    src0_ptr = (const cuda_t *) src0->data;
+
+    // Handle src1 - convert if necessary
+    if (src1->type == src0_type) {
+        src1_ptr = (const cuda_t *) src1->data;
+    } else {
+        // Convert src1 to target type using traits conversion functions
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_alloc.alloc(ne_src1);
 
-        src1_f16 = src1_f16_alloc.get();
+        const auto convert_func = traits::get_nc_converter(src1->type);
+        GGML_ASSERT(convert_func != nullptr);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        src1_ptr = src1_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
     }
 
-    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    // Setup destination buffer
+    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
     char * dst_t;
-
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
-
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
-
+    cublasComputeType_t cu_compute_type = traits::compute_type;
+    cudaDataType_t cu_data_type = traits::data_type;
+    cudaDataType_t cu_data_type_a = traits::data_type;
+    cudaDataType_t cu_data_type_b = traits::data_type;
+    const void * alpha = traits::get_alpha();
+    const void * beta = traits::get_beta();
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
-
-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const float beta_f32 = 0.0f;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
-
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+        if constexpr (src0_type == GGML_TYPE_F32) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else {
+            dst_t = (char *) dst_temp.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(cuda_t);
+            nbd3 /= sizeof(float) / sizeof(cuda_t);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
-                           src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
-                    beta,  (char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11,       s12,       // strideB
-                beta,  dst_t, cu_data_type, ne0,        ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
+                beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        size_t src1_stride_size = sizeof(cuda_t);
+
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
+                src0_ptr, src1_ptr, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
                 nbd2, nbd3,
                 r2, r3);
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    // Convert output back to F32 if needed
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
     }
 }
 
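The change above replaces the F16-only cuBLAS path with a traits struct that binds each supported src0 type to its CUDA element type, cuBLAS compute/data types, GEMM scalars, and conversion helper at compile time; the runtime switch in ggml_cuda_mul_mat_batched_cublas then maps the tensor's ggml_type onto the right template instantiation. Below is a minimal standalone sketch of that same compile-time dispatch pattern, with a hypothetical enum, trait fields, and function names that are illustrative stand-ins rather than ggml code:

#include <cstdio>

// Hypothetical stand-ins for ggml_type / cudaDataType_t, for illustration only.
enum example_type { EX_F32, EX_F16 };

template <example_type T>
struct gemm_traits;            // primary template left undefined: unsupported types fail to compile

template <>
struct gemm_traits<EX_F32> {
    using elem_t = float;
    static constexpr const char * data_type = "R_32F";
};

template <>
struct gemm_traits<EX_F16> {
    using elem_t = unsigned short;              // placeholder for a half type
    static constexpr const char * data_type = "R_16F";
};

template <example_type T>
static void gemm_impl() {
    using traits = gemm_traits<T>;
    // everything type-specific comes from the traits; the body stays generic
    printf("elem size = %zu, data type = %s\n", sizeof(typename traits::elem_t), traits::data_type);
}

// runtime -> compile-time dispatch, mirroring the switch in ggml_cuda_mul_mat_batched_cublas
static void gemm(example_type t) {
    switch (t) {
        case EX_F32: gemm_impl<EX_F32>(); break;
        case EX_F16: gemm_impl<EX_F16>(); break;
    }
}

int main() {
    gemm(EX_F32);
    gemm(EX_F16);
}

The payoff is that adding another element type only requires a new trait specialization and one switch case; the GEMM body itself never branches on the type at runtime.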
@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    // TODO update for generic tensor parallelism
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
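For reference, the leading dimensions and batch strides passed to cublasGemmStridedBatchedEx in the batched path above are element counts derived from ggml's byte strides: nb01/nb00 is the row stride of src0 in elements, nb02/nb00 its per-matrix stride, and s11/s12/s13 are the equivalent element strides of src1 (recomputed when src1 was converted into a contiguous temporary). A small host-side sketch of that arithmetic, using made-up tensor shapes that are not taken from the commit, looks like this:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical F16 src0 of shape [ne00, ne01, ne02] = [128, 64, 8],
    // stored contiguously; nb* are byte strides as in ggml_tensor.
    const int64_t ne00 = 128, ne01 = 64, ne02 = 8;
    const size_t  ts   = 2;                 // sizeof(half)
    const size_t  nb00 = ts;                // stride between elements in a row
    const size_t  nb01 = nb00*ne00;         // stride between rows
    const size_t  nb02 = nb01*ne01;         // stride between matrices in the batch

    // What the batched GEMM receives: strides in elements, not bytes.
    const int64_t lda     = nb01/nb00;      // leading dimension of A
    const int64_t strideA = nb02/nb00;      // elements between consecutive A matrices

    printf("lda = %lld, strideA = %lld, batches = %lld\n",
           (long long) lda, (long long) strideA, (long long) ne02);
    return 0;
}

In the actual call the batch count is ne12*ne13 and the src1 and dst strides are handled analogously (s11/s12 for B, ne0/ne1*ne0 for C).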