Skip to content

Commit d3bdcf8

Browse files
committed
Added BF16 support for transposed copies in the CUDA copy path (ggml_cpy_flt_cuda)
1 parent 30d4607 commit d3bdcf8

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -184,10 +184,7 @@ static void ggml_cpy_flt_cuda(
184184
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
185185
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
186186

187-
if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
188-
std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
189-
&& transpose){ //transpose
190-
187+
if (transpose){ //transpose
191188
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
192189
(ne00 + CUDA_CPY_TILE_DIM - 1) / CUDA_CPY_TILE_DIM,
193190
(ne/(ne00*ne01) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM );
@@ -436,7 +433,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
436433
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
437434
}
438435
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
439-
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
436+
if(src0->op == GGML_OP_TRANSPOSE){
437+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
438+
} else {
439+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
440+
}
440441
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
441442
if (contiguous_srcs) {
442443
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6523,6 +6523,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
65236523
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
65246524
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
65256525
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6526+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 4}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
65266527

65276528
test_cases.emplace_back(new test_cont());
65286529
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7258,10 +7259,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
72587259
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
72597260
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
72607261
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
7262+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
72617263

72627264
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72637265
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72647266
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
7267+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
72657268

72667269

72677270
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)