3535
3636// ggml-backend interface
3737
38- std::vector<ggml_backend_buffer_type_t >& ggml_backend_cpu_get_extra_buffers_type () {
38+ std::vector<ggml_backend_buffer_type_t > & ggml_backend_cpu_get_extra_buffer_types () {
3939 static std::vector<ggml_backend_buffer_type_t > bufts = []() {
4040 std::vector<ggml_backend_buffer_type_t > bufts;
4141
@@ -57,23 +57,27 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
5757 }
5858#endif
5959
60- bufts.push_back (NULL );
61-
6260 return bufts;
6361 }();
6462
6563 return bufts;
6664}
6765
6866static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type (ggml_backend_dev_t device) {
69- return ggml_backend_cpu_get_extra_buffers_type ().data ();
67+ static std::vector<ggml_backend_buffer_type_t > extra_bufts = [] {
68+ std::vector<ggml_backend_buffer_type_t > bufts = ggml_backend_cpu_get_extra_buffer_types ();
69+ bufts.push_back (nullptr );
70+ return bufts;
71+ }();
72+
73+ return extra_bufts.data ();
7074
7175 GGML_UNUSED (device);
7276}
7377
7478static bool ggml_backend_cpu_is_extra_buffer_type (ggml_backend_buffer_type_t buft) {
75- for (auto * extra : ggml_backend_cpu_get_extra_buffers_type ()) {
76- if (extra && extra == buft) {
79+ for (auto * extra : ggml_backend_cpu_get_extra_buffer_types ()) {
80+ if (extra == buft) {
7781 return true ;
7882 }
7983 }
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
397401 return true ;
398402 }
399403
400- // extra_buffer_op?
401- for (auto extra : ggml_backend_cpu_get_extra_buffers_type ()) {
402- if (extra) {
403- auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context ;
404- if (buf_extra && buf_extra->supports_op (dev, op)) {
405- return true ;
406- }
407- }
408- }
409-
410- // the other case need host buffer.
411- for (int i = 0 ; i < GGML_MAX_SRC; i++) {
412- if (op->src [i] && op->src [i]->buffer && !ggml_backend_buft_is_host (op->src [i]->buffer ->buft )) {
413- return false ;
404+ // check extra buffer types
405+ // note: only the first 4 sources are checked for extra buffer types to reduce overhead; raise the loop bound if necessary
406+ for (int i = 0 ; i < 4 ; i++) {
407+ if (op->src [i] && op->src [i]->buffer &&
408+ ggml_backend_cpu_is_extra_buffer_type (op->src [i]->buffer ->buft )) {
409+ auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src [i]->buffer ->buft ->context ;
410+ return buf_extra->supports_op (dev, op);
414411 }
415412 }
416413
0 commit comments