
Commit d8c0a3c

Revert "CUDA: add bf16 and f32 support to cublas_mul_mat_batched (ggml-org#14361)"
1 parent 186227f commit d8c0a3c

3 files changed, 76 additions, 158 deletions

ggml/src/ggml-cuda/convert.cu

Lines changed: 0 additions & 22 deletions

@@ -728,25 +728,3 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
-
-to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_F32:
-            return convert_unary_cuda<float, nv_bfloat16>;
-        case GGML_TYPE_F16:
-            return convert_unary_cuda<half, nv_bfloat16>;
-        default:
-            return nullptr;
-    }
-}
-
-to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_F16:
-            return convert_unary_cuda<half, float>;
-        case GGML_TYPE_BF16:
-            return convert_unary_cuda<nv_bfloat16, float>;
-        default:
-            return nullptr;
-    }
-}
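For context, both deleted getters dispatch to a single templated conversion routine. Below is a minimal, self-contained sketch of how such a convert_unary_cuda-style kernel can be structured; the kernel body, launch configuration, and the contiguous (rather than strided) layout are illustrative assumptions, not the exact ggml implementation:

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Illustrative sketch only: a unary conversion kernel templated on the source
// and destination element types, in the spirit of convert_unary_cuda<src_t, dst_t>.
// The real ggml "nc" converters also take per-dimension sizes and strides;
// this sketch assumes contiguous data for brevity.
template <typename src_t, typename dst_t>
static __global__ void k_convert_unary(const void * vx, dst_t * y, const int64_t n) {
    const src_t * x = (const src_t *) vx;

    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    y[i] = (dst_t) (float) x[i]; // round-trip through float covers the half/bf16 pairs
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda_sketch(const void * x, dst_t * y, const int64_t n, cudaStream_t stream) {
    const int block_size = 256;
    const int64_t num_blocks = (n + block_size - 1) / block_size;
    k_convert_unary<src_t, dst_t><<<(unsigned) num_blocks, block_size, 0, stream>>>(x, y, n);
}

The getters removed above simply returned the matching instantiation of such a template for each (src, dst) type pair, or nullptr for unsupported pairs.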

ggml/src/ggml-cuda/convert.cuh

Lines changed: 0 additions & 5 deletions

@@ -22,10 +22,5 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
-typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
-typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
-
-to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
-to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 76 additions & 131 deletions

@@ -1750,7 +1750,7 @@ static void ggml_cuda_op_mul_mat(
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const void * src0_as_f16, const void * src1_as_f16, char * dst,
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1773,139 +1773,91 @@
     ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
 }
 
-// Type traits for mapping ggml types to CUDA/cuBLAS types
-template<ggml_type T>
-struct batched_mul_mat_traits;
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_F32> {
-    using cuda_type = float;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
-    static inline const cudaDataType_t data_type = CUDA_R_32F;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
-    static inline const float alpha = 1.0f;
-    static inline const float beta = 0.0f;
-    static inline const void* get_alpha() { static const float val = alpha; return &val; }
-    static inline const void* get_beta() { static const float val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
-};
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_BF16> {
-    using cuda_type = nv_bfloat16;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
-    static inline const cudaDataType_t data_type = CUDA_R_16BF;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
-    static inline const float alpha = 1.0f;
-    static inline const float beta = 0.0f;
-    static inline const void* get_alpha() { static const float val = alpha; return &val; }
-    static inline const void* get_beta() { static const float val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
-};
-
-template<>
-struct batched_mul_mat_traits<GGML_TYPE_F16> {
-    using cuda_type = half;
-    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
-    static inline const cudaDataType_t data_type = CUDA_R_16F;
-    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
-    static inline const half alpha = 1.0;
-    static inline const half beta = 0.0;
-    static inline const void* get_alpha() { static const half val = alpha; return &val; }
-    static inline const void* get_beta() { static const half val = beta; return &val; }
-    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
-};
-
-template<ggml_type src0_type>
-static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    using traits = batched_mul_mat_traits<src0_type>;
-    using cuda_t = typename traits::cuda_type;
-
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
+
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == src0_type);
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
+
     cudaStream_t main_stream = ctx.stream();
+
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
+    const half * src0_f16 = (const half *) src0->data;
     float * dst_ddf = (float *) dst->data;
+
+    const half * src1_f16 = (const half *) src1->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
 
-    const cuda_t * src0_ptr = nullptr;
-    const cuda_t * src1_ptr = nullptr;
-
-    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
-
-    // Handle src0
-    src0_ptr = (const cuda_t *) src0->data;
-
-    // Handle src1 - convert if necessary
-    if (src1->type == src0_type) {
-        src1_ptr = (const cuda_t *) src1->data;
-    } else {
-        // Convert src1 to target type using traits conversion functions
+    // convert src1 to fp16
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
-        src1_alloc.alloc(ne_src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        const auto convert_func = traits::get_nc_converter(src1->type);
-        GGML_ASSERT(convert_func != nullptr);
-        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-        src1_ptr = src1_alloc.get();
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
     }
 
-    // Setup destination buffer
-    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
+    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
     char * dst_t;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
-    cublasComputeType_t cu_compute_type = traits::compute_type;
-    cudaDataType_t cu_data_type = traits::data_type;
-    cudaDataType_t cu_data_type_a = traits::data_type;
-    cudaDataType_t cu_data_type_b = traits::data_type;
-    const void * alpha = traits::get_alpha();
-    const void * beta = traits::get_beta();
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        if constexpr (src0_type == GGML_TYPE_F32) {
-            dst_t = (char *) dst_ddf; // Direct F32 output
-        } else {
-            dst_t = (char *) dst_temp.alloc(ne_dst);
-            nbd2 /= sizeof(float) / sizeof(cuda_t);
-            nbd3 /= sizeof(float) / sizeof(cuda_t);
-        }
+        dst_t = (char *) dst_f16.alloc(ne_dst);
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
     } else {
         dst_t = (char *) dst_ddf;
+
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
+        cu_data_type    = CUDA_R_32F;
+
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta  = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta  = &beta_f32;
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
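The block removed above is the compile-time traits table that let one templated implementation cover F32, BF16, and F16. As a standalone illustration of that pattern (with a stand-in enum, since pulling in the full ggml headers is beside the point here):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Illustrative reconstruction of the reverted traits idea: map a tensor type
// (a stand-in enum here) to the matching CUDA element type and cuBLAS enums
// at compile time, so one templated GEMM path can serve F32/F16/BF16.
enum example_type { EX_TYPE_F32, EX_TYPE_F16, EX_TYPE_BF16 };

template <example_type T> struct gemm_traits;

template <> struct gemm_traits<EX_TYPE_F32> {
    using cuda_type = float;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static constexpr cudaDataType_t      data_type    = CUDA_R_32F;
};

template <> struct gemm_traits<EX_TYPE_F16> {
    using cuda_type = half;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
    static constexpr cudaDataType_t      data_type    = CUDA_R_16F;
};

template <> struct gemm_traits<EX_TYPE_BF16> {
    using cuda_type = nv_bfloat16;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; // bf16 accumulates in f32
    static constexpr cudaDataType_t      data_type    = CUDA_R_16BF;
};

// A templated GEMM wrapper then reads gemm_traits<T>::data_type and
// gemm_traits<T>::compute_type instead of switching on the type at run time.

The revert trades this genericity for the simpler F16-only path restored above.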
@@ -1915,15 +1867,35 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
+                                   src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
+                            beta,  ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
+                            cu_compute_type,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
+                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+                beta,  dst_t,    cu_data_type, ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
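The restored fast path issues a single strided-batched GEMM when there is no broadcast and all matrices in the batch share one uniform stride. A minimal, self-contained example of the same cublasGemmStridedBatchedEx call pattern; the matrix sizes, fill values, and the omitted error checking are assumptions for brevity:

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

int main() {
    const int m = 4, n = 3, k = 2, batch = 8;

    std::vector<half> a(m*k*batch), b(k*n*batch), c(m*n*batch);
    for (auto & v : a) v = __float2half(1.0f);
    for (auto & v : b) v = __float2half(2.0f);

    half *da, *db, *dc;
    cudaMalloc(&da, a.size()*sizeof(half));
    cudaMalloc(&db, b.size()*sizeof(half));
    cudaMalloc(&dc, c.size()*sizeof(half));
    cudaMemcpy(da, a.data(), a.size()*sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(db, b.data(), b.size()*sizeof(half), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // With CUBLAS_COMPUTE_16F, alpha/beta must point to half values,
    // mirroring the alpha_f16/beta_f16 pointers in the code above.
    const half alpha = __float2half(1.0f), beta = __float2half(0.0f);

    // A is stored k x m per batch and transposed (CUBLAS_OP_T), as above;
    // the per-batch strides are given in elements of each array's type.
    cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
            m, n, k,
            &alpha, da, CUDA_R_16F, k, (long long) m*k, // strideA
                    db, CUDA_R_16F, k, (long long) k*n, // strideB
            &beta,  dc, CUDA_R_16F, m, (long long) m*n, // strideC
            batch,
            CUBLAS_COMPUTE_16F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaMemcpy(c.data(), dc, c.size()*sizeof(half), cudaMemcpyDeviceToHost);
    printf("c[0] = %f (expected k*1*2 = 4.0)\n", __half2float(c[0]));

    cublasDestroy(handle);
    cudaFree(da); cudaFree(db); cudaFree(dc);
    return 0;
}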
@@ -1934,55 +1906,34 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
-        size_t src1_stride_size = sizeof(cuda_t);
-
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_ptr, src1_ptr, dst_t,
+                src0_f16, src1_f16, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
-                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
                 nbd2, nbd3,
                 r2, r3);
-
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif
 
-    // Convert output back to F32 if needed
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
-        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
-    }
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_BF16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_F16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
 }
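When src0 is broadcast across the batch (r2 or r3 > 1), the per-matrix strides are no longer uniform, so the slow path above fills per-matrix pointer arrays (on the GPU, via k_compute_batched_ptrs) and calls cublasGemmBatchedEx. Below is a self-contained sketch of that call with the pointer arrays built on the host instead; the i02 = i12/r2 broadcast mapping mirrors the code above, while the sizes and host-side setup are simplifications for illustration:

#include <vector>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

int main() {
    const int m = 4, n = 3, k = 2;
    const int ne02 = 2, ne12 = 6;  // src0 holds 2 matrices, src1 holds 6
    const int r2 = ne12/ne02;      // each src0 matrix serves r2 = 3 GEMMs

    half *da, *db, *dc;
    cudaMalloc(&da, (size_t) ne02*m*k*sizeof(half));
    cudaMalloc(&db, (size_t) ne12*k*n*sizeof(half));
    cudaMalloc(&dc, (size_t) ne12*m*n*sizeof(half));
    cudaMemset(da, 0, (size_t) ne02*m*k*sizeof(half));
    cudaMemset(db, 0, (size_t) ne12*k*n*sizeof(half));

    // One A/B/C pointer per GEMM in the batch; A repeats to express broadcast.
    std::vector<const void *> ha(ne12), hb(ne12);
    std::vector<void *> hc(ne12);
    for (int i12 = 0; i12 < ne12; ++i12) {
        ha[i12] = da + (i12/r2)*m*k; // broadcast src0: i02 = i12 / r2
        hb[i12] = db + i12*k*n;
        hc[i12] = dc + i12*m*n;
    }

    const void **dpa, **dpb;
    void **dpc;
    cudaMalloc((void **) &dpa, ne12*sizeof(void *));
    cudaMalloc((void **) &dpb, ne12*sizeof(void *));
    cudaMalloc((void **) &dpc, ne12*sizeof(void *));
    cudaMemcpy(dpa, ha.data(), ne12*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(dpb, hb.data(), ne12*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(dpc, hc.data(), ne12*sizeof(void *), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const half alpha = __float2half(1.0f), beta = __float2half(0.0f);
    cublasGemmBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
            m, n, k,
            &alpha, dpa, CUDA_R_16F, k,
                    dpb, CUDA_R_16F, k,
            &beta,  dpc, CUDA_R_16F, m,
            ne12,
            CUBLAS_COMPUTE_16F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cublasDestroy(handle);
    cudaFree(dpa); cudaFree(dpb); cudaFree(dpc);
    cudaFree(da);  cudaFree(db);  cudaFree(dc);
    return 0;
}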

@@ -2034,12 +1985,6 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    //TODO update for generic tensor parallelism
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
-    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
-    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
-
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2048,8 +1993,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
-        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+               !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
