Commit 3a50a13 (1 parent: 71967c2)

ggml : add backend registry / device interfaces to BLAS backend

File tree: 5 files changed, +235 / -58 lines


ggml/include/ggml-backend.h

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ extern "C" {
 
     // Functions that may be obtained using ggml_backend_reg_get_proc_address
     typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(const float *);
+    typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t, int);
 
     //
     // Backend registry
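
The new typedef is intended to be resolved by name at runtime rather than called directly. A minimal sketch of that pattern (not part of this commit; `reg` and `backend` are assumed to be a registry and a backend already obtained by the caller):

    // look up the generic "set n_threads" entry point exposed by some registries
    ggml_backend_set_n_threads_t set_n_threads_fn =
        (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
    if (set_n_threads_fn != NULL) {
        // only backends that export this entry point (CPU, and BLAS after this commit) are affected
        set_n_threads_fn(backend, 8);
    }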

ggml/include/ggml-blas.h

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);
 // for openblas and blis, this will also set the number of threads used for blas operations
 GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
 
+GGML_API ggml_backend_reg_t ggml_backend_blas_reg(void);
+
 
 #ifdef __cplusplus
 }

ggml/src/CMakeLists.txt

Lines changed: 8 additions & 6 deletions
@@ -190,22 +190,24 @@ if (GGML_BLAS)
         # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
         find_package(PkgConfig REQUIRED)
         if (${GGML_BLAS_VENDOR} MATCHES "Generic")
-            pkg_check_modules(DepBLAS REQUIRED blas)
+            pkg_check_modules(DepBLAS blas)
         elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
             # As of openblas v0.3.22, the 64-bit is named openblas64.pc
             pkg_check_modules(DepBLAS openblas64)
             if (NOT DepBLAS_FOUND)
-                pkg_check_modules(DepBLAS REQUIRED openblas)
+                pkg_check_modules(DepBLAS openblas)
             endif()
         elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
-            pkg_check_modules(DepBLAS REQUIRED blis)
+            add_compile_definitions(GGML_BLAS_USE_BLIS)
+            pkg_check_modules(DepBLAS blis)
         elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
-            pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+            pkg_check_modules(DepBLAS blas-atlas)
         elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
-            pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+            pkg_check_modules(DepBLAS flexiblas_api)
         elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+            add_compile_definitions(GGML_BLAS_USE_MKL)
             # all Intel* libraries share the same include path
-            pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+            pkg_check_modules(DepBLAS mkl-sdl)
         elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
             # this doesn't provide pkg-config
             # suggest to assign BLAS_INCLUDE_DIRS on your own

ggml/src/ggml-backend.cpp

Lines changed: 19 additions & 1 deletion
@@ -525,6 +525,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name
 #include "ggml-cuda.h"
 #endif
 
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -534,6 +538,10 @@ struct ggml_backend_registry {
         register_backend(ggml_backend_cuda_reg());
 #endif
 
+#ifdef GGML_USE_BLAS
+        register_backend(ggml_backend_blas_reg());
+#endif
+
         register_backend(ggml_backend_cpu_reg());
 
         // TODO: sycl, metal, vulkan, kompute, cann
@@ -1221,11 +1229,21 @@ static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg
     GGML_UNUSED(index);
 }
 
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_cpu_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
 static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
     /* .get_name         = */ ggml_backend_cpu_reg_get_name,
     /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
     /* .get_device       = */ ggml_backend_cpu_reg_get_device,
-    /* .get_proc_address = */ NULL,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
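
With the BLAS backend now registered in ggml_backend_registry (when GGML_USE_BLAS is defined), it also participates in the global device enumeration. The fragment below is an illustrative sketch only, assuming the enumeration helpers ggml_backend_dev_count(), ggml_backend_dev_get() and ggml_backend_dev_name() declared in ggml-backend.h, plus <stdio.h>:

    // print every registered device; with GGML_USE_BLAS a "BLAS" device is
    // listed alongside the CPU (and any GPU backends compiled in)
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("device #%zu: %s\n", i, ggml_backend_dev_name(dev));
    }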

ggml/src/ggml-blas.cpp

Lines changed: 205 additions & 51 deletions
@@ -18,38 +18,14 @@
 #endif
 
 struct ggml_backend_blas_context {
-    int n_threads = GGML_DEFAULT_N_THREADS;
+    int n_threads = std::thread::hardware_concurrency()/2;
     std::unique_ptr<char[]> work_data;
     size_t work_size = 0;
 #ifndef GGML_USE_OPENMP
     std::vector<std::future<void>> tasks;
 #endif
 };
 
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
 static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -235,7 +211,7 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
 
 // backend interface
 
-static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
 
     GGML_UNUSED(backend);
@@ -285,29 +261,8 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     GGML_UNUSED(backend);
 }
 
-static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-                                          op->src[1]->type == GGML_TYPE_F32 &&
-                                          ggml_is_matrix(src0) &&
-                                          ggml_is_matrix(src1) &&
-                                          ggml_is_contiguous(src0) &&
-                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
 static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
+    /* .get_name                = */ ggml_backend_blas_get_name,
     /* .free                    = */ ggml_backend_blas_free,
     /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
@@ -319,8 +274,8 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -337,7 +292,7 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_blas_guid(),
         /* .interface = */ blas_backend_i,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
        /* .context   = */ ctx,
     };
 
@@ -364,3 +319,202 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
     ctx->n_threads = n_threads;
 }
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+#if defined(GGML_USE_ACCELERATE)
+    return "Accelerate";
+#elif defined(GGML_BLAS_USE_MKL)
+    return "MKL";
+#elif defined(GGML_BLAS_USE_BLIS)
+    return "BLIS";
+#elif defined(GGML_BLAS_USE_NVPL)
+    return "NVPL";
+#elif defined(OPENBLAS_VERSION)
+    return "OpenBLAS";
+#else
+    return "BLAS";
+#endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ false,
+        /* events      */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return (ggml_is_contiguous(src0) &&
+                    ggml_is_contiguous(src1) &&
+                    src1->type == GGML_TYPE_F32 &&
+                    (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch));
+        }
+
+        case GGML_OP_OUT_PROD:
+            return (op->src[0]->type == GGML_TYPE_F32 &&
+                    op->src[1]->type == GGML_TYPE_F32 &&
+                    ggml_is_matrix(src0) &&
+                    ggml_is_matrix(src1) &&
+                    ggml_is_contiguous(src0) &&
+                    (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .iface   = */ ggml_backend_blas_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
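
Putting the pieces together, here is a hypothetical, self-contained usage sketch (not part of the commit): it obtains the BLAS registry and its single device, creates the backend, and adjusts the thread count through the generic proc-address mechanism. It assumes the library was configured with GGML_BLAS=ON and that the ggml_backend_dev_name() helper from ggml-backend.h is available.

    #include "ggml-backend.h"
    #include "ggml-blas.h"

    #include <stdio.h>

    int main(void) {
        ggml_backend_reg_t reg = ggml_backend_blas_reg();           // registry added by this commit
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);  // the single "BLAS" device

        printf("initializing backend for device: %s\n", ggml_backend_dev_name(dev));

        ggml_backend_t backend = ggml_backend_blas_init();          // backend->device now points at dev

        // generic thread-count control, mirroring what the CPU backend exposes
        ggml_backend_set_n_threads_t set_n_threads =
            (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_n_threads) {
            set_n_threads(backend, 4);
        }

        // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...

        ggml_backend_free(backend);
        return 0;
    }

Direct calls to ggml_backend_blas_set_n_threads() keep working; the proc-address route simply lets callers avoid a hard dependency on ggml-blas.h.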
