 #endif
 
 struct ggml_backend_blas_context {
-    int n_threads = GGML_DEFAULT_N_THREADS;
+    int n_threads = std::thread::hardware_concurrency() / 2;
     std::unique_ptr<char[]> work_data;
     size_t work_size = 0;
 #ifndef GGML_USE_OPENMP
     std::vector<std::future<void>> tasks;
 #endif
 };
 
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
 static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
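
One caveat with the new default above: the C++ standard permits `std::thread::hardware_concurrency()` to return 0 when the value cannot be determined, in which case `hardware_concurrency() / 2` also evaluates to 0 threads. A minimal guarded sketch (hypothetical helper, not part of this commit):

```cpp
#include <algorithm>
#include <thread>

// Hypothetical guard, not in the diff: clamp the default so a
// hardware_concurrency() of 0 (value unknown) still yields one thread.
static int blas_default_n_threads() {
    const unsigned int hc = std::thread::hardware_concurrency(); // may be 0
    return (int) std::max(1u, hc / 2);
}
```
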
@@ -235,7 +211,7 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
 
 // backend interface
 
-static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
 
     GGML_UNUSED(backend);
@@ -285,29 +261,8 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     GGML_UNUSED(backend);
 }
 
-static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-            op->src[1]->type == GGML_TYPE_F32 &&
-            ggml_is_matrix(src0) &&
-            ggml_is_matrix(src1) &&
-            ggml_is_contiguous(src0) &&
-            (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
 static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
+    /* .get_name                = */ ggml_backend_blas_get_name,
     /* .free                    = */ ggml_backend_blas_free,
     /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
@@ -319,8 +274,8 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
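
With `.supports_op` and `.supports_buft` cleared in `ggml_backend_i`, support queries are now answered by the device interface added further down. A sketch of probing MUL_MAT support through the device, assuming the registry API from `ggml-backend.h` (`ggml_backend_reg_dev_get`, `ggml_backend_dev_supports_op`) and the no-alloc tensor metadata path from `ggml.h`:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Sketch: build op metadata only (no_alloc) and ask the BLAS device
// whether it would take a [m,k] x [k,n] matrix multiplication.
static bool blas_would_handle_mul_mat(int64_t m, int64_t n, int64_t k) {
    struct ggml_init_params ip = {
        /* .mem_size   = */ 16*1024*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, m);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, n);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // ne0 = m, ne1 = n, ne10 = k

    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0);
    const bool ok = ggml_backend_dev_supports_op(dev, c);

    ggml_free(ctx);
    return ok;
}
```

With the thresholds in this diff, `blas_would_handle_mul_mat(64, 64, 64)` would return true, while `blas_would_handle_mul_mat(8, 8, 8)` would fall below `min_batch` and return false.
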
@@ -337,7 +292,7 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_blas_guid(),
         /* .interface = */ blas_backend_i,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
         /* .context   = */ ctx,
     };
 
@@ -364,3 +319,202 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
     ctx->n_threads = n_threads;
 }
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+#if defined(GGML_USE_ACCELERATE)
+    return "Accelerate";
+#elif defined(GGML_BLAS_USE_MKL)
+    return "MKL";
+#elif defined(GGML_BLAS_USE_BLIS)
+    return "BLIS";
+#elif defined(GGML_BLAS_USE_NVPL)
+    return "NVPL";
+#elif defined(OPENBLAS_VERSION)
+    return "OpenBLAS";
+#else
+    return "BLAS";
+#endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free  = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ false,
+        /* events      */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return (ggml_is_contiguous(src0) &&
+                    ggml_is_contiguous(src1) &&
+                    src1->type == GGML_TYPE_F32 &&
+                    (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch));
+        }
+
+        case GGML_OP_OUT_PROD:
+            return (op->src[0]->type == GGML_TYPE_F32 &&
+                    op->src[1]->type == GGML_TYPE_F32 &&
+                    ggml_is_matrix(src0) &&
+                    ggml_is_matrix(src1) &&
+                    ggml_is_contiguous(src0) &&
+                    (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .iface   = */ ggml_backend_blas_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
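
Taken together, the reg, device, and proc-address pieces let callers discover and configure the BLAS backend without linking against BLAS-specific entry points. A usage sketch follows; only `ggml_backend_blas_reg()` and the `"ggml_backend_set_n_threads"` proc name come from this diff, the rest is assumed from the public `ggml-backend.h` registry API (`ggml_backend_reg_dev_get`, `ggml_backend_dev_init`, `ggml_backend_reg_get_proc_address`):

```cpp
#include <cstdio>
#include "ggml-backend.h"

// Matches the signature of ggml_backend_blas_set_n_threads above.
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t, int);

int main(void) {
    ggml_backend_reg_t reg = ggml_backend_blas_reg();
    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);

    // same result as calling ggml_backend_blas_init() directly
    ggml_backend_t backend = ggml_backend_dev_init(dev, /* params = */ NULL);

    // resolve the thread setter generically, as a scheduler or loader would
    ggml_backend_set_n_threads_t set_n_threads = (ggml_backend_set_n_threads_t)
        ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
    if (set_n_threads != NULL) {
        set_n_threads(backend, 4);
    }

    printf("initialized backend: %s\n", ggml_backend_name(backend));
    ggml_backend_free(backend);
    return 0;
}
```
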