@@ -10,7 +10,7 @@ uint32_t
1010backend_buffer_type_get_name (struct vn_cs_encoder *enc, struct vn_cs_decoder *dec, struct virgl_apir_context *ctx) {
1111 UNUSED (ctx);
1212 ggml_backend_buffer_type_t buft;
13- buft = vn_decode_ggml_buft (dec);
13+ buft = vn_decode_ggml_buffer_type (dec);
1414
1515 const char *string = buft->iface .get_name (buft);
1616
@@ -25,7 +25,7 @@ uint32_t
2525backend_buffer_type_get_alignment (struct vn_cs_encoder *enc, struct vn_cs_decoder *dec, struct virgl_apir_context *ctx) {
2626 UNUSED (ctx);
2727 ggml_backend_buffer_type_t buft;
28- buft = vn_decode_ggml_buft (dec);
28+ buft = vn_decode_ggml_buffer_type (dec);
2929
3030 size_t value = buft->iface .get_alignment (buft);
3131 vn_encode_size_t (enc, &value);
@@ -37,7 +37,7 @@ uint32_t
3737backend_buffer_type_get_max_size (struct vn_cs_encoder *enc, struct vn_cs_decoder *dec, struct virgl_apir_context *ctx) {
3838 UNUSED (ctx);
3939 ggml_backend_buffer_type_t buft;
40- buft = vn_decode_ggml_buft (dec);
40+ buft = vn_decode_ggml_buffer_type (dec);
4141
4242 size_t value = buft->iface .get_max_size (buft);
4343 vn_encode_size_t (enc, &value);
@@ -49,7 +49,7 @@ uint32_t
4949backend_buffer_type_is_host (struct vn_cs_encoder *enc, struct vn_cs_decoder *dec, struct virgl_apir_context *ctx) {
5050 UNUSED (ctx);
5151 ggml_backend_buffer_type_t buft;
52- buft = vn_decode_ggml_buft (dec);
52+ buft = vn_decode_ggml_buffer_type (dec);
5353
5454 bool is_host = buft->iface .is_host (buft);
5555 vn_encode_bool_t (enc, &is_host);
@@ -60,15 +60,32 @@ backend_buffer_type_is_host(struct vn_cs_encoder *enc, struct vn_cs_decoder *dec
6060uint32_t
6161backend_buffer_type_alloc_buffer (struct vn_cs_encoder *enc, struct vn_cs_decoder *dec, struct virgl_apir_context *ctx) {
6262 UNUSED (ctx);
63- ggml_backend_buffer_type_t buft;
64- buft = vn_decode_ggml_buft (dec);
63+ #if APIR_ALLOC_FROM_HOST_PTR
64+ uint32_t shmem_res_id;
65+ vn_decode_virtgpu_shmem_res_id (dec, &shmem_res_id);
6566
67+ void *shmem_data = ctx->iface .get_shmem_ptr (ctx->virgl_ctx , shmem_res_id);
68+ if (!shmem_data) {
69+ FATAL (" Couldn't get the shmem addr from virgl :/" );
70+ }
71+ #else
72+ ggml_backend_buffer_type_t buft;
73+ buft = vn_decode_ggml_buffer_type (dec);
74+ #endif
6675 size_t size;
6776 vn_decode_size_t (dec, &size);
6877
69- ggml_backend_buffer_t buffer = buft->iface .alloc_buffer (buft, size);
70- apir_buffer_handle_t *buffer_handle = (apir_buffer_handle_t *) buffer;
71- vn_encode_ggml_buffer_handle (enc, buffer_handle);
78+ ggml_backend_buffer_t buffer;
79+ #if APIR_ALLOC_FROM_HOST_PTR
80+ #define MAX_TENSOR_SIZE 323205120
81+ buffer = dev->iface .buffer_from_host_ptr (dev, shmem_data, size, MAX_TENSOR_SIZE);
82+
83+ vn_encode_ggml_buffer_type (enc, buffer->buft );
84+ #else
85+ buffer = buft->iface .alloc_buffer (buft, size);
86+ #endif
87+
88+ vn_encode_ggml_buffer (enc, buffer);
7289
7390 if (buffer) {
7491 track_backend_buffer (buffer);
0 commit comments