Commit 4887fa5

CUDA: add bf16 and f32 support to cublas_mul_mat_batched
1 parent 63a7bb3 commit 4887fa5

File tree

4 files changed: +202 -51 lines changed

ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/convert.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
tests/test-backend-ops.cpp


ggml/src/ggml-cuda/convert.cu

Lines changed: 11 additions & 0 deletions

@@ -728,3 +728,14 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
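For orientation, here is a hedged sketch of how the new getter is meant to be used; it mirrors the call site added to ggml-cuda.cu further down. The wrapper function name and the assumption of a contiguous source buffer are illustrative, not part of the commit:

    // Illustrative only: convert a contiguous ne0 x ne1 x ne2 x ne3 F32 buffer to BF16
    // with the converter returned by the new getter. Strides are in elements of the source.
    #include "convert.cuh"   // assumed include; declares ggml_get_to_bf16_nc_cuda

    static void convert_f32_to_bf16_contiguous(const float * src, nv_bfloat16 * dst,
                                               int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                                               cudaStream_t stream) {
        const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(GGML_TYPE_F32);
        GGML_ASSERT(to_bf16_cuda != nullptr);   // the getter returns nullptr for unsupported source types
        to_bf16_cuda(src, dst,
                     ne0, ne1, ne2, ne3,          // logical dimensions
                     ne0, ne0*ne1, ne0*ne1*ne2,   // row/plane/batch strides for a contiguous layout
                     stream);
    }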

ggml/src/ggml-cuda/convert.cuh

Lines changed: 2 additions & 0 deletions

@@ -23,4 +23,6 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);

 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 185 additions & 49 deletions
@@ -1748,8 +1748,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }

+template<typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1777,7 +1778,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));

     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);

     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1791,64 +1792,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co

     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;

-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+
+    const half * src0_f16 = nullptr;
+    const half * src1_f16 = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float * src0_f32 = nullptr;
+    const float * src1_f32 = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);

-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void*)((const char*)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);

-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void*)((const char*)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);

-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void*)((const char*)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }

     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;

-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t cu_data_type;
+    cudaDataType_t cu_data_type_a;
+    cudaDataType_t cu_data_type_b;
+
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_32F;
+        cu_data_type_a = CUDA_R_32F;
+        cu_data_type_b = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_16BF;
+        cu_data_type_a = CUDA_R_16BF;
+        cu_data_type_b = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type = CUDA_R_16F;
+        cu_data_type_a = CUDA_R_16F;
+        cu_data_type_b = CUDA_R_16F;
+    }

-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];

     const half alpha_f16 = 1.0f;
     const half beta_f16 = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32 = 0.0f;

-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const void * alpha;
+    const void * beta;

-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta = &beta_f16;
+    }

-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }

     int id = ggml_cuda_get_device();
@@ -1889,11 +1979,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                                use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                                use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11, s12, // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11, s12, // strideB
                 beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
@@ -1905,34 +2000,74 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                                use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                                use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
+        size_t src1_stride_size = use_f32_path ? sizeof(float) :
+                                  use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if( use_f32_path ) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const float*)src0_ptr, (const float*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const nv_bfloat16*)src0_ptr, (const nv_bfloat16*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const half*)src0_ptr, (const half*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        }
+
         CUDA_CHECK(cudaGetLastError());

         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
 #endif

-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            //already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }

@@ -1992,8 +2127,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
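To summarize the selection logic added above: the batched path keeps the original F16 behaviour, BF16 operands run with F32 accumulation, and F32 operands use a pure F32 GEMM written straight into dst with no staging buffer. A condensed sketch of that mapping follows; the helper name is illustrative, and the committed code additionally tracks separate A/B operand types and the GGML_PREC_DEFAULT handling shown in the diff:

    #include <cublas_v2.h>   // cublasComputeType_t, cudaDataType_t
    #include "ggml.h"        // ggml_type, GGML_TYPE_*

    // Illustrative helper, not part of the commit: map src0's type to the cuBLAS
    // compute type and operand data type chosen by ggml_cuda_mul_mat_batched_cublas.
    struct batched_gemm_types {
        cublasComputeType_t compute;
        cudaDataType_t      operand;   // data type of the A and B matrices
    };

    static batched_gemm_types pick_batched_gemm_types(ggml_type src0_type) {
        switch (src0_type) {
            case GGML_TYPE_F32:  return { CUBLAS_COMPUTE_32F, CUDA_R_32F  };  // F32 in, F32 out directly
            case GGML_TYPE_BF16: return { CUBLAS_COMPUTE_32F, CUDA_R_16BF };  // BF16 operands, F32 accumulation
            default:             return { CUBLAS_COMPUTE_16F, CUDA_R_16F  };  // original F16 path
        }
    }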

tests/test-backend-ops.cpp

Lines changed: 4 additions & 2 deletions

@@ -4425,8 +4425,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (auto nr : {1,4}) {
             for (uint32_t m = 0; m < 2; ++m) {
                 for (uint32_t k = 0; k < 2; ++k) {
-                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
-                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                    for(ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}){
+                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                    }
                 }
             }
         }
