diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ecfc1c856cb59..5e3dc0a3d0cc1 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -211,7 +211,7 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) // ** backend sessions struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id) noexcept(false); + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); void allocate(int dev_id) noexcept(false); @@ -1631,10 +1631,13 @@ void ggml_hexagon_session::release() noexcept(true) { } } -ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) { +ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { buffer_type.context = nullptr; repack_buffer_type.context = nullptr; + buffer_type.device = dev; + repack_buffer_type.device = dev; + try { allocate(dev_id); @@ -3628,7 +3631,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { devices[i].iface = ggml_backend_hexagon_device_i; devices[i].reg = reg; try { - devices[i].context = new ggml_hexagon_session(i); + devices[i].context = new ggml_hexagon_session(i, &devices[i]); } catch (std::exception const &exc) { GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i); devices[i].context = nullptr;