Skip to content

Commit 2ec8689

Browse files
Merge pull request #96 from BradHutchings/work-in-progress
Work in progress
2 parents 271b22e + 0b0b371 commit 2ec8689

File tree

19 files changed

+462
-167
lines changed

19 files changed

+462
-167
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ To get this from the llama.cpp source base, there are few files that need to be
4848

4949
3. [src/llama-context.cpp](src/llama-context-mmojo.cpp) -- COSMOCC doesn't have std::fill in its Standard Templates Library.
5050

51-
4. [src/llama-hparams.cpp](src/llama-hapams-mmojo.cpp) -- COSMOCC doesn't have std::max in its Standard Templates Library.
51+
4. [src/llama-hparams.cpp](src/llama-hparams-mmojo.cpp) -- COSMOCC doesn't have std::max in its Standard Templates Library.
5252

5353
5. [tools/server/server.cpp](tools/server/server-mmojo.cpp) -- Support embedded or adjacent "args" file, fix Cosmo name conflict with "defer" task member, add additional meta data to `model_meta`.
5454

common/arg.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1612,7 +1612,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
16121612
[](common_params & params, const std::string & value) {
16131613
params.antiprompt.emplace_back(value);
16141614
}
1615-
).set_examples({LLAMA_EXAMPLE_MAIN}));
1615+
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
16161616
add_opt(common_arg(
16171617
{"-sp", "--special"},
16181618
string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"),
@@ -2655,6 +2655,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
26552655
params.i_chunk = value;
26562656
}
26572657
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
2658+
add_opt(common_arg(
2659+
{"--show-statistics"},
2660+
string_format("show imatrix statistics and then exit (default: %s)", params.show_statistics ? "true" : "false"),
2661+
[](common_params & params) {
2662+
params.show_statistics = true;
2663+
}
2664+
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
26582665
add_opt(common_arg(
26592666
{"--parse-special"},
26602667
string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),

common/common.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,10 @@ struct common_params {
432432
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
433433
int32_t i_chunk = 0; // start processing from this chunk
434434

435-
bool process_output = false; // collect data for the output tensor
436-
bool compute_ppl = true; // whether to compute perplexity
437-
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
435+
bool process_output = false; // collect data for the output tensor
436+
bool compute_ppl = true; // whether to compute perplexity
437+
bool show_statistics = false; // show imatrix statistics per tensor
438+
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
438439

439440
// cvector-generator params
440441
int n_pca_batch = 100;

completion-ui/completion/scripts.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,7 @@ async function GetModelInfoFromServer() {
10581058
metadata = data0.meta;
10591059
const modelName = metadata["general.name"];
10601060
const n_ctx_train = metadata["n_ctx_train"];
1061+
const n_ctx = metadata["n_ctx"];
10611062

10621063
if (kLogging) console.log("json.data[0]:\n");
10631064
if (kLogging) console.log(data0);
@@ -1071,7 +1072,7 @@ async function GetModelInfoFromServer() {
10711072
if (kLogging) console.log("meta[\"n_ctx_train\"]:\n");
10721073
if (kLogging) console.log(n_ctx_train);
10731074

1074-
contextWindowSize = n_ctx_train;
1075+
contextWindowSize = n_ctx;
10751076
elements.model.innerHTML = modelName;
10761077
}
10771078
catch(exc) {

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND)
102102
if (GGML_STATIC)
103103
if (WIN32)
104104
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
105-
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
105+
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
106106
else ()
107-
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
107+
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
108108
endif()
109109
else()
110-
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
110+
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
111111
endif()
112112

113113
if (GGML_CUDA_NO_VMM)

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,13 @@
22

33
#include "ggml-common.h"
44

5-
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
6-
*dst = *src;
7-
}
8-
9-
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
10-
*dst = __float2half(*src);
11-
}
12-
13-
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
14-
*dst = *src;
15-
}
16-
17-
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
18-
*dst = *src;
19-
}
20-
21-
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
22-
*dst = *src;
5+
template<typename src_t, typename dst_t>
6+
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
7+
if constexpr (std::is_same_v<src_t, dst_t>) {
8+
*dst = *src;
9+
} else {
10+
*dst = float(*src);
11+
}
2312
}
2413

2514
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
@@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
230219
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
231220
}
232221

233-
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
234-
convert_f32_f32((const float *)cxi, (float *)cdsti);
235-
}
236-
237-
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
238-
convert_f32_f16((const float *)cxi, (half *)cdsti);
239-
}
240-
241-
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
242-
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
243-
}
244-
245-
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
246-
convert_f16_f16((const half *)cxi, (half *)cdsti);
247-
}
248-
249-
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
250-
convert_f16_f32((const half *)cxi, (float *)cdsti);
222+
template<typename src_t, typename dst_t>
223+
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
224+
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
251225
}

ggml/src/ggml-cuda/cpy.cu

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
99

1010
template <cpy_kernel_t cpy_1>
11-
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
12-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
13-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
14-
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
11+
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
12+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
13+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
14+
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
1515
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
1616

1717
if (i >= ne) {
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
139139
#endif
140140
}
141141

142-
static void ggml_cpy_f16_f32_cuda(
142+
template<typename src_t, typename dst_t>
143+
static void ggml_cpy_flt_cuda(
143144
const char * cx, char * cdst, const int ne,
144145
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
145146
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
146147

147148
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
148-
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
149-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
150-
}
151-
152-
static void ggml_cpy_f32_f32_cuda(
153-
const char * cx, char * cdst, const int ne,
154-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
155-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
156-
157-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
158-
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
159-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
160-
}
161-
162-
static void ggml_cpy_f32_bf16_cuda(
163-
const char * cx, char * cdst, const int ne,
164-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
165-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
166-
167-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
168-
cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
169-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
170-
}
171-
172-
static void ggml_cpy_f32_f16_cuda(
173-
const char * cx, char * cdst, const int ne,
174-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
175-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
176-
177-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
178-
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
149+
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
179150
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
180151
}
181152

@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
307278
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
308279
}
309280

310-
static void ggml_cpy_f16_f16_cuda(
311-
const char * cx, char * cdst, const int ne,
312-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
313-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
314-
315-
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
316-
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
317-
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
318-
}
319-
320281
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
321282
const int64_t ne = ggml_nelements(src0);
322283
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
372333
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
373334
}
374335
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
375-
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
336+
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
376337
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
377-
ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
338+
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
378339
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
379-
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
340+
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
380341
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
381342
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
382343
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
403364
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
404365
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
405366
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
406-
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
367+
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
368+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
369+
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
407370
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
408-
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
371+
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
372+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
373+
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
374+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
375+
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
376+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
377+
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
409378
} else {
410379
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
411380
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
430399
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
431400
return nullptr;
432401
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
433-
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
402+
return (void*) cpy_flt<cpy_1_flt<float, float>>;
434403
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
435-
return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
404+
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
436405
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
437-
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
406+
return (void*) cpy_flt<cpy_1_flt<float, half>>;
438407
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
439408
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
440409
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
458427
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
459428
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
460429
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
461-
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
430+
return (void*) cpy_flt<cpy_1_flt<half, half>>;
431+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
432+
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
462433
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
463-
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
434+
return (void*) cpy_flt<cpy_1_flt<half, float>>;
435+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
436+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
437+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
438+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
439+
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
440+
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
464441
} else {
465442
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
466443
ggml_type_name(src0->type), ggml_type_name(src1->type));

0 commit comments

Comments
 (0)