@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19411941        if  (device->fp16 ) {
19421942            device_extensions.push_back (" VK_KHR_shader_float16_int8"  );
19431943        }
1944-         device->name  = device-> properties . deviceName . data ( );
1944+         device->name  = GGML_VK_NAME +  std::to_string (idx );
19451945
19461946        device_create_info = {
19471947            vk::DeviceCreateFlags (),
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19681968
19691969        device->buffer_type  = {
19701970            /*  .iface    = */   ggml_backend_vk_buffer_type_interface,
1971-             /*  .device   = */   nullptr ,
1971+             /*  .device   = */   ggml_backend_reg_dev_get ( ggml_backend_vk_reg (), idx) ,
19721972            /*  .context  = */   new  ggml_backend_vk_buffer_type_context{ device->name , device },
19731973        };
19741974
@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
63786378            /*  .get_alloc_size   = */   ggml_backend_cpu_buffer_type ()->iface .get_alloc_size ,
63796379            /*  .is_host          = */   ggml_backend_cpu_buffer_type ()->iface .is_host ,
63806380        },
6381-         /*  .device   = */   nullptr ,
6381+         /*  .device   = */   ggml_backend_reg_dev_get ( ggml_backend_vk_reg (),  0 ) ,
63826382        /*  .context  = */   nullptr ,
63836383    };
63846384
@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
65816581    UNUSED (backend);
65826582}
65836583
6584- static  bool  ggml_backend_vk_supports_op (ggml_backend_t  backend, const  ggml_tensor * op) {
6585-     //  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
// TODO: enable async and synchronize
// Backend v-table for the Vulkan backend; optional async/event hooks are left NULL.
static ggml_backend_i ggml_backend_vk_interface = {
    /* .get_name                = */ ggml_backend_vk_name,
    /* .free                    = */ ggml_backend_vk_free,
    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
    /* .supports_op             = */ NULL,  // moved to the device interface (ggml_backend_vk_device_i)
    /* .supports_buft           = */ NULL,  // moved to the device interface
    /* .offload_op              = */ NULL,  // moved to the device interface
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
};
6604+ 
6605+ static  ggml_guid_t  ggml_backend_vk_guid () {
6606+     static  ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b  };
6607+     return  &guid;
6608+ }
6609+ 
6610+ ggml_backend_t  ggml_backend_vk_init (size_t  dev_num) {
6611+     VK_LOG_DEBUG (" ggml_backend_vk_init("   << dev_num << " )"  );
6612+ 
6613+     ggml_backend_vk_context * ctx = new  ggml_backend_vk_context;
6614+     ggml_vk_init (ctx, dev_num);
6615+ 
6616+     ggml_backend_t  vk_backend = new  ggml_backend {
6617+         /*  .guid      = */   ggml_backend_vk_guid (),
6618+         /*  .interface = */   ggml_backend_vk_interface,
6619+         /*  .device    = */   ggml_backend_reg_dev_get (ggml_backend_vk_reg (), dev_num),
6620+         /*  .context   = */   ctx,
6621+     };
6622+ 
6623+     return  vk_backend;
6624+ }
6625+ 
6626+ bool  ggml_backend_is_vk (ggml_backend_t  backend) {
6627+     return  backend != NULL  && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6628+ }
6629+ 
// Public API: number of Vulkan devices selected by the instance.
int ggml_backend_vk_get_device_count() {
    return ggml_vk_get_device_count();
}
6633+ 
6634+ void  ggml_backend_vk_get_device_description (int  device, char  * description, size_t  description_size) {
6635+     GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6636+     int  dev_idx = vk_instance.device_indices [device];
6637+     ggml_vk_get_device_description (dev_idx, description, description_size);
6638+ }
6639+ 
6640+ void  ggml_backend_vk_get_device_memory (int  device, size_t  * free, size_t  * total) {
6641+     GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6642+ 
6643+     vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6644+ 
6645+     vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6646+ 
6647+     for  (const  vk::MemoryHeap& heap : memprops.memoryHeaps ) {
6648+         if  (heap.flags  & vk::MemoryHeapFlagBits::eDeviceLocal) {
6649+             *total = heap.size ;
6650+             *free = heap.size ;
6651+             break ;
6652+         }
6653+     }
6654+ }
6655+ 
6656+ // ////////////////////////
6657+ 
// Per-device registry context: identifies one Vulkan device and caches its strings.
struct ggml_backend_vk_device_context {
    int device;              // logical device index used by the Vulkan backend
    std::string name;        // backend name + index (set when the device list is built)
    std::string description; // human-readable device description
};
6663+ 
6664+ static  const  char  * ggml_backend_vk_device_get_name (ggml_backend_dev_t  dev) {
6665+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6666+     return  ctx->name .c_str ();
6667+ }
6668+ 
6669+ static  const  char  * ggml_backend_vk_device_get_description (ggml_backend_dev_t  dev) {
6670+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6671+     return  ctx->description .c_str ();
6672+ }
6673+ 
6674+ static  void  ggml_backend_vk_device_get_memory (ggml_backend_dev_t  device, size_t  * free, size_t  * total) {
6675+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context ;
6676+     ggml_backend_vk_get_device_memory (ctx->device , free, total);
6677+ }
6678+ 
6679+ static  ggml_backend_buffer_type_t  ggml_backend_vk_device_get_buffer_type (ggml_backend_dev_t  dev) {
6680+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6681+     return  ggml_backend_vk_buffer_type (ctx->device );
6682+ }
6683+ 
// The host buffer type is shared across all devices, so the handle is unused.
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    UNUSED(dev);
    return ggml_backend_vk_host_buffer_type();
}
65866688
// Every Vulkan device is reported as a full-featured GPU.
static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
    UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
}
6693+ 
// Fill the generic device property struct using the per-device accessors above.
static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_vk_device_get_name(dev);
    props->description = ggml_backend_vk_device_get_description(dev);
    props->type        = ggml_backend_vk_device_get_type(dev);
    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
    // async transfers and events are not implemented yet (see the NULL hooks above);
    // a host buffer type is available
    props->caps = {
        /* async       */ false,
        /* host_buffer */ true,
        /* events      */ false,
    };
}
6705+ 
6706+ static  ggml_backend_t  ggml_backend_vk_device_init (ggml_backend_dev_t  dev, const  char  * params) {
6707+     UNUSED (params);
6708+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6709+     return  ggml_backend_vk_init (ctx->device );
6710+ }
6711+ 
6712+ static  bool  ggml_backend_vk_device_supports_op (ggml_backend_dev_t  dev, const  ggml_tensor * op) {
65876713    switch  (op->op ) {
65886714        case  GGML_OP_UNARY:
65896715            switch  (ggml_get_unary_op (op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
67016827            return  false ;
67026828    }
67036829
6704-     UNUSED (backend);
6705- }
6706- 
6707- static  bool  ggml_backend_vk_offload_op (ggml_backend_t  backend, const  ggml_tensor * op) {
6708-     const  int  min_batch_size = 32 ;
6709- 
6710-     return  (op->ne [1 ] >= min_batch_size && op->op  != GGML_OP_GET_ROWS) ||
6711-            (op->ne [2 ] >= min_batch_size && op->op  == GGML_OP_MUL_MAT_ID);
6712- 
6713-     UNUSED (backend);
6830+     UNUSED (dev);
67146831}
67156832
6716- static  bool  ggml_backend_vk_supports_buft ( ggml_backend_t  backend , ggml_backend_buffer_type_t  buft) {
6833+ static  bool  ggml_backend_vk_device_supports_buft ( ggml_backend_dev_t  dev , ggml_backend_buffer_type_t  buft) {
67176834    if  (buft->iface .get_name  != ggml_backend_vk_buffer_type_name) {
67186835        return  false ;
67196836    }
67206837
6838+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
67216839    ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context ;
6722-     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context ;
6723- 
6724-     return  buft_ctx->device  == ctx->device ;
6725- }
6726- 
6727- //  TODO: enable async and synchronize
6728- static  ggml_backend_i ggml_backend_vk_interface = {
6729-     /*  .get_name                = */   ggml_backend_vk_name,
6730-     /*  .free                    = */   ggml_backend_vk_free,
6731-     /*  .get_default_buffer_type = */   ggml_backend_vk_get_default_buffer_type,
6732-     /*  .set_tensor_async        = */   NULL ,  //  ggml_backend_vk_set_tensor_async,
6733-     /*  .get_tensor_async        = */   NULL ,  //  ggml_backend_vk_get_tensor_async,
6734-     /*  .cpy_tensor_async        = */   NULL ,  //  ggml_backend_vk_cpy_tensor_async,
6735-     /*  .synchronize             = */   NULL ,  //  ggml_backend_vk_synchronize,
6736-     /*  .graph_plan_create       = */   NULL ,
6737-     /*  .graph_plan_free         = */   NULL ,
6738-     /*  .graph_plan_update       = */   NULL ,
6739-     /*  .graph_plan_compute      = */   NULL ,
6740-     /*  .graph_compute           = */   ggml_backend_vk_graph_compute,
6741-     /*  .supports_op             = */   ggml_backend_vk_supports_op,
6742-     /*  .supports_buft           = */   ggml_backend_vk_supports_buft,
6743-     /*  .offload_op              = */   ggml_backend_vk_offload_op,
6744-     /*  .event_record            = */   NULL ,
6745-     /*  .event_wait              = */   NULL ,
6746- };
67476840
6748- static  ggml_guid_t  ggml_backend_vk_guid () {
6749-     static  ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b  };
6750-     return  &guid;
6841+     return  buft_ctx->device ->idx  == ctx->device ;
67516842}
67526843
6753- ggml_backend_t   ggml_backend_vk_init ( size_t  dev_num ) {
6754-     VK_LOG_DEBUG ( " ggml_backend_vk_init( "  << dev_num <<  " ) " ) ;
6844+ static   bool   ggml_backend_vk_device_offload_op ( ggml_backend_dev_t  dev,  const  ggml_tensor * op ) {
6845+     const   int  min_batch_size =  32 ;
67556846
6756-     ggml_backend_vk_context * ctx =  new  ggml_backend_vk_context; 
6757-     ggml_vk_init (ctx, dev_num );
6847+     return  (op-> ne [ 1 ] >= min_batch_size && op-> op  != GGML_OP_GET_ROWS) || 
6848+            (op-> ne [ 2 ] >= min_batch_size && op-> op  == GGML_OP_MUL_MAT_ID );
67586849
6759-     ggml_backend_t  vk_backend = new  ggml_backend {
6760-         /*  .guid      = */   ggml_backend_vk_guid (),
6761-         /*  .interface = */   ggml_backend_vk_interface,
6762-         /*  .device    = */   nullptr ,
6763-         /*  .context   = */   ctx,
6764-     };
6850+     UNUSED (dev);
6851+ }
6852+ 
// Device v-table: op/buffer-type support queries live here (moved off the
// backend interface); events and host-pointer buffers are not implemented.
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
    /* .get_name             = */ ggml_backend_vk_device_get_name,
    /* .get_description      = */ ggml_backend_vk_device_get_description,
    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
    /* .get_type             = */ ggml_backend_vk_device_get_type,
    /* .get_props            = */ ggml_backend_vk_device_get_props,
    /* .init_backend         = */ ggml_backend_vk_device_init,
    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
67656870
6766-     return  vk_backend;
// Registry name: the backend-wide GGML_VK_NAME constant.
static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
    UNUSED(reg);
    return GGML_VK_NAME;
}
67686875
6769- bool  ggml_backend_is_vk (ggml_backend_t  backend) {
6770-     return  backend != NULL  && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6876+ static  size_t  ggml_backend_vk_reg_get_device_count (ggml_backend_reg_t  reg) {
6877+     UNUSED (reg);
6878+     return  ggml_backend_vk_get_device_count ();
67716879}
67726880
6773- int  ggml_backend_vk_get_device_count () {
6774-     return  ggml_vk_get_device_count ();
6775- }
6881+ static  ggml_backend_dev_t  ggml_backend_vk_reg_get_device (ggml_backend_reg_t  reg, size_t  device) {
6882+     static  std::vector<ggml_backend_dev_t > devices;
67766883
6777- void  ggml_backend_vk_get_device_description (int  device, char  * description, size_t  description_size) {
6778-     ggml_vk_get_device_description (device, description, description_size);
6779- }
6884+     static  bool  initialized = false ;
67806885
6781- void  ggml_backend_vk_get_device_memory (int  device, size_t  * free, size_t  * total) {
6782-     GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6886+     {
6887+         static  std::mutex mutex;
6888+         std::lock_guard<std::mutex> lock (mutex);
6889+         if  (!initialized) {
6890+             for  (size_t  i = 0 ; i < ggml_backend_vk_get_device_count (); i++) {
6891+                 ggml_backend_vk_device_context * ctx = new  ggml_backend_vk_device_context;
6892+                 char  desc[256 ];
6893+                 ggml_backend_vk_get_device_description (i, desc, sizeof (desc));
6894+                 ctx->device  = i;
6895+                 ctx->name  = GGML_VK_NAME + std::to_string (i);
6896+                 ctx->description  = desc;
6897+                 devices.push_back (new  ggml_backend_device {
6898+                     /*  .iface   = */   ggml_backend_vk_device_i,
6899+                     /*  .reg     = */   reg,
6900+                     /*  .context = */   ctx,
6901+                 });
6902+             }
6903+             initialized = true ;
6904+         }
6905+     }
67836906
6784-     vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6907+     GGML_ASSERT (device < devices.size ());
6908+     return  devices[device];
6909+ }
67856910
6786-     vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6911+ static  const  struct  ggml_backend_reg_i  ggml_backend_vk_reg_i = {
6912+     /*  .get_name         = */   ggml_backend_vk_reg_get_name,
6913+     /*  .get_device_count = */   ggml_backend_vk_reg_get_device_count,
6914+     /*  .get_device       = */   ggml_backend_vk_reg_get_device,
6915+     /*  .get_proc_address = */   NULL ,
6916+ };
67876917
6788-      for  ( const  vk::MemoryHeap& heap : memprops. memoryHeaps ) {
6789-          if  (heap. flags  & vk::MemoryHeapFlagBits::eDeviceLocal)  {
6790-              *total = heap. size ; 
6791-             *free = heap. size ; 
6792-              break ;
6793-         } 
6794-     } 
6918+ ggml_backend_reg_t   ggml_backend_vk_reg ( ) {
6919+     static  ggml_backend_reg reg =  {
6920+         /*  .iface    =  */  ggml_backend_vk_reg_i, 
6921+         /*  .context =  */   nullptr , 
6922+     } ;
6923+ 
6924+     return  ® 
67956925}
67966926
67976927//  Extension availability
0 commit comments