@@ -1748,8 +1748,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+template <typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
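
Making the kernel a template lets the same pointer-setup code be instantiated for half, nv_bfloat16 and float. A minimal standalone sketch of that pattern (hypothetical kernel and buffer names, not the ggml kernel itself):

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // One templated kernel body, one instantiation per element type at each launch site.
    template <typename T>
    static __global__ void k_fill_batch_ptrs(const T * base, const void ** ptrs,
                                             size_t stride_bytes, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            // batch entry i points stride_bytes further into the source buffer
            ptrs[i] = (const char *) base + (size_t) i*stride_bytes;
        }
    }

    // usage (assumed buffers/stream):
    //   k_fill_batch_ptrs<half>       <<<blocks, threads, 0, stream>>>(a_f16,  ptrs, nb2, n);
    //   k_fill_batch_ptrs<nv_bfloat16><<<blocks, threads, 0, stream>>>(a_bf16, ptrs, nb2, n);
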
@@ -1777,7 +1778,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1791,64 +1792,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path  = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
 
-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+
+    const half        * src0_f16  = nullptr;
+    const half        * src1_f16  = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float       * src0_f32  = nullptr;
+    const float       * src1_f32  = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void *)((const char *) src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void *)((const char *) src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void *)((const char *) src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;
 
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t      cu_data_type;
+    cudaDataType_t      cu_data_type_a;
+    cudaDataType_t      cu_data_type_b;
+
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+        cu_data_type_a  = CUDA_R_32F;
+        cu_data_type_b  = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_16BF;
+        cu_data_type_a  = CUDA_R_16BF;
+        cu_data_type_b  = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type    = CUDA_R_16F;
+        cu_data_type_a  = CUDA_R_16F;
+        cu_data_type_b  = CUDA_R_16F;
+    }
 
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const half  alpha_f16 = 1.0f;
     const half  beta_f16  = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32  = 0.0f;
 
-    const void * alpha = &alpha_f16;
-    const void * beta  = &beta_f16;
+    const void * alpha;
+    const void * beta;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta  = &beta_f16;
+    }
 
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type    = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta  = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
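
cuBLAS expects the alpha/beta scalars to match the compute type rather than the storage type: float with CUBLAS_COMPUTE_32F, half with CUBLAS_COMPUTE_16F, which is why both the F32 and BF16 paths point them at the f32 constants. A sketch of the same mapping as a helper (pick_gemm_types and its parameters are illustrative assumptions, not ggml code):

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    struct gemm_types {
        cublasComputeType_t compute;
        cudaDataType_t      ab;     // type of the A/B matrices handed to cuBLAS
        const void *        alpha;  // scalar type must follow the compute type
        const void *        beta;
    };

    static gemm_types pick_gemm_types(bool use_f32, bool use_bf16,
                                      const float * f32_one, const float * f32_zero,
                                      const half  * f16_one, const half  * f16_zero) {
        if (use_f32)  { return { CUBLAS_COMPUTE_32F, CUDA_R_32F,  f32_one, f32_zero }; }
        if (use_bf16) { return { CUBLAS_COMPUTE_32F, CUDA_R_16BF, f32_one, f32_zero }; } // bf16 data, fp32 accumulate
        return          { CUBLAS_COMPUTE_16F, CUDA_R_16F,  f16_one, f16_zero };
    }
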
@@ -1889,11 +1979,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path  ? (const void *) src0_f32 :
+                                use_bf16_path ? (const void *) src0_bf16 : (const void *) src0_f16;
+        const void * src1_ptr = use_f32_path  ? (const void *) src1_f32 :
+                                use_bf16_path ? (const void *) src1_bf16 : (const void *) src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F,     nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F,     s11,       s12,       // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
                 beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
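
For reference, a self-contained sketch of this call shape with BF16 A/B/C, FP32 accumulation and float scalars (shapes, leading dimensions, no transposition and the helper name are assumptions; the code above passes ggml's own strides and CUBLAS_OP_T for A):

    #include <cublas_v2.h>
    #include <cuda_bf16.h>

    // Batched C = A*B with BF16 storage and FP32 accumulation, one m*k / k*n / m*n
    // matrix per batch entry, packed back to back in device memory.
    static cublasStatus_t gemm_bf16_batched(cublasHandle_t handle,
            const nv_bfloat16 * A, const nv_bfloat16 * B, nv_bfloat16 * C,
            int m, int n, int k, int batch) {
        const float alpha = 1.0f, beta = 0.0f; // FP32 scalars to match CUBLAS_COMPUTE_32F
        return cublasGemmStridedBatchedEx(handle,
                CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                &alpha, A, CUDA_R_16BF, m, (long long) m*k, // strideA
                        B, CUDA_R_16BF, k, (long long) k*n, // strideB
                &beta,  C, CUDA_R_16BF, m, (long long) m*n, // strideC
                batch,
                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    }
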
@@ -1905,34 +2000,74 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        const void * src0_ptr = use_f32_path  ? (const void *) src0_f32 :
+                                use_bf16_path ? (const void *) src0_bf16 : (const void *) src0_f16;
+        const void * src1_ptr = use_f32_path  ? (const void *) src1_f32 :
+                                use_bf16_path ? (const void *) src1_bf16 : (const void *) src1_f16;
+
+        size_t src1_stride_size = use_f32_path  ? sizeof(float) :
+                                  use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if (use_f32_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const float *) src0_ptr, (const float *) src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const nv_bfloat16 *) src0_ptr, (const nv_bfloat16 *) src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const half *) src0_ptr, (const half *) src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        }
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,     nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,     s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            // already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }
 
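
When the GEMM wrote into the BF16 scratch buffer, the last branch widens it back into the F32 dst tensor through ggml's to_fp32 converter. A standalone sketch of the element-wise conversion this amounts to (kernel name and launch configuration are assumptions):

    #include <cstdint>
    #include <cuda_bf16.h>

    // one thread per element: y[i] = (float) x[i]
    static __global__ void k_bf16_to_f32(const nv_bfloat16 * x, float * y, int64_t n) {
        const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = __bfloat162float(x[i]);
        }
    }

    // usage (assumed): k_bf16_to_f32<<<(n + 255)/256, 256, 0, stream>>>(dst_bf16, dst_f32, n);
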
@@ -1992,8 +2127,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+            && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+            && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {