4040
4141// ggml-backend interface
4242
43- std::vector<ggml_backend_buffer_type_t >& ggml_backend_cpu_get_extra_buffers_type () {
43+ std::vector<ggml_backend_buffer_type_t > & ggml_backend_cpu_get_extra_buffer_types () {
4444 static std::vector<ggml_backend_buffer_type_t > bufts = []() {
4545 std::vector<ggml_backend_buffer_type_t > bufts;
4646
@@ -62,23 +62,27 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
6262 }
6363#endif
6464
65- bufts.push_back (NULL );
66-
6765 return bufts;
6866 }();
6967
7068 return bufts;
7169}
7270
7371static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type (ggml_backend_dev_t device) {
74- return ggml_backend_cpu_get_extra_buffers_type ().data ();
72+ static std::vector<ggml_backend_buffer_type_t > extra_bufts = [] {
73+ std::vector<ggml_backend_buffer_type_t > bufts = ggml_backend_cpu_get_extra_buffer_types ();
74+ bufts.push_back (nullptr );
75+ return bufts;
76+ }();
77+
78+ return extra_bufts.data ();
7579
7680 GGML_UNUSED (device);
7781}
7882
7983static bool ggml_backend_cpu_is_extra_buffer_type (ggml_backend_buffer_type_t buft) {
80- for (auto * extra : ggml_backend_cpu_get_extra_buffers_type ()) {
81- if (extra && extra == buft) {
84+ for (auto * extra : ggml_backend_cpu_get_extra_buffer_types ()) {
85+ if (extra == buft) {
8286 return true ;
8387 }
8488 }
@@ -402,20 +406,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
402406 return true ;
403407 }
404408
405- // extra_buffer_op?
406- for (auto extra : ggml_backend_cpu_get_extra_buffers_type ()) {
407- if (extra) {
408- auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context ;
409- if (buf_extra && buf_extra->supports_op (dev, op)) {
410- return true ;
411- }
412- }
413- }
414-
415- // the other case need host buffer.
416- for (int i = 0 ; i < GGML_MAX_SRC; i++) {
417- if (op->src [i] && op->src [i]->buffer && !ggml_backend_buft_is_host (op->src [i]->buffer ->buft )) {
418- return false ;
409+ // check extra buffer types
410+ // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
411+ for (int i = 0 ; i < 4 ; i++) {
412+ if (op->src [i] && op->src [i]->buffer &&
413+ ggml_backend_cpu_is_extra_buffer_type (op->src [i]->buffer ->buft )) {
414+ auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src [i]->buffer ->buft ->context ;
415+ return buf_extra->supports_op (dev, op);
419416 }
420417 }
421418
0 commit comments