@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19411941 if (device->fp16 ) {
19421942 device_extensions.push_back (" VK_KHR_shader_float16_int8" );
19431943 }
1944- device->name = device-> properties . deviceName . data ( );
1944+ device->name = GGML_VK_NAME + std::to_string (idx );
19451945
19461946 device_create_info = {
19471947 vk::DeviceCreateFlags (),
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19681968
19691969 device->buffer_type = {
19701970 /* .iface = */ ggml_backend_vk_buffer_type_interface,
1971- /* .device = */ nullptr ,
1971+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_vk_reg (), idx) ,
19721972 /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name , device },
19731973 };
19741974
@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
63786378 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type ()->iface .get_alloc_size ,
63796379 /* .is_host = */ ggml_backend_cpu_buffer_type ()->iface .is_host ,
63806380 },
6381- /* .device = */ nullptr ,
6381+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_vk_reg (), 0 ) ,
63826382 /* .context = */ nullptr ,
63836383 };
63846384
@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
65816581 UNUSED (backend);
65826582}
65836583
6584- static bool ggml_backend_vk_supports_op (ggml_backend_t backend, const ggml_tensor * op) {
6585- // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
6584+ // TODO: enable async and synchronize
6585+ static ggml_backend_i ggml_backend_vk_interface = {
6586+ /* .get_name = */ ggml_backend_vk_name,
6587+ /* .free = */ ggml_backend_vk_free,
6588+ /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6589+ /* .set_tensor_async = */ NULL , // ggml_backend_vk_set_tensor_async,
6590+ /* .get_tensor_async = */ NULL , // ggml_backend_vk_get_tensor_async,
6591+ /* .cpy_tensor_async = */ NULL , // ggml_backend_vk_cpy_tensor_async,
6592+ /* .synchronize = */ NULL , // ggml_backend_vk_synchronize,
6593+ /* .graph_plan_create = */ NULL ,
6594+ /* .graph_plan_free = */ NULL ,
6595+ /* .graph_plan_update = */ NULL ,
6596+ /* .graph_plan_compute = */ NULL ,
6597+ /* .graph_compute = */ ggml_backend_vk_graph_compute,
6598+ /* .supports_op = */ NULL ,
6599+ /* .supports_buft = */ NULL ,
6600+ /* .offload_op = */ NULL ,
6601+ /* .event_record = */ NULL ,
6602+ /* .event_wait = */ NULL ,
6603+ };
6604+
6605+ static ggml_guid_t ggml_backend_vk_guid () {
6606+ static ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b };
6607+ return &guid;
6608+ }
6609+
6610+ ggml_backend_t ggml_backend_vk_init (size_t dev_num) {
6611+ VK_LOG_DEBUG (" ggml_backend_vk_init(" << dev_num << " )" );
6612+
6613+ ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6614+ ggml_vk_init (ctx, dev_num);
6615+
6616+ ggml_backend_t vk_backend = new ggml_backend {
6617+ /* .guid = */ ggml_backend_vk_guid (),
6618+ /* .interface = */ ggml_backend_vk_interface,
6619+ /* .device = */ ggml_backend_reg_dev_get (ggml_backend_vk_reg (), dev_num),
6620+ /* .context = */ ctx,
6621+ };
6622+
6623+ return vk_backend;
6624+ }
6625+
6626+ bool ggml_backend_is_vk (ggml_backend_t backend) {
6627+ return backend != NULL && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6628+ }
6629+
6630+ int ggml_backend_vk_get_device_count () {
6631+ return ggml_vk_get_device_count ();
6632+ }
6633+
6634+ void ggml_backend_vk_get_device_description (int device, char * description, size_t description_size) {
6635+ GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6636+ int dev_idx = vk_instance.device_indices [device];
6637+ ggml_vk_get_device_description (dev_idx, description, description_size);
6638+ }
6639+
6640+ void ggml_backend_vk_get_device_memory (int device, size_t * free, size_t * total) {
6641+ GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6642+
6643+ vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6644+
6645+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6646+
6647+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps ) {
6648+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6649+ *total = heap.size ;
6650+ *free = heap.size ;
6651+ break ;
6652+ }
6653+ }
6654+ }
6655+
6656+ // ////////////////////////
6657+
6658+ struct ggml_backend_vk_device_context {
6659+ int device;
6660+ std::string name;
6661+ std::string description;
6662+ };
6663+
6664+ static const char * ggml_backend_vk_device_get_name (ggml_backend_dev_t dev) {
6665+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6666+ return ctx->name .c_str ();
6667+ }
6668+
6669+ static const char * ggml_backend_vk_device_get_description (ggml_backend_dev_t dev) {
6670+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6671+ return ctx->description .c_str ();
6672+ }
6673+
6674+ static void ggml_backend_vk_device_get_memory (ggml_backend_dev_t device, size_t * free, size_t * total) {
6675+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context ;
6676+ ggml_backend_vk_get_device_memory (ctx->device , free, total);
6677+ }
6678+
6679+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type (ggml_backend_dev_t dev) {
6680+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6681+ return ggml_backend_vk_buffer_type (ctx->device );
6682+ }
6683+
6684+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type (ggml_backend_dev_t dev) {
6685+ UNUSED (dev);
6686+ return ggml_backend_vk_host_buffer_type ();
6687+ }
65866688
6689+ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type (ggml_backend_dev_t dev) {
6690+ UNUSED (dev);
6691+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
6692+ }
6693+
6694+ static void ggml_backend_vk_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
6695+ props->name = ggml_backend_vk_device_get_name (dev);
6696+ props->description = ggml_backend_vk_device_get_description (dev);
6697+ props->type = ggml_backend_vk_device_get_type (dev);
6698+ ggml_backend_vk_device_get_memory (dev, &props->memory_free , &props->memory_total );
6699+ props->caps = {
6700+ /* async */ false ,
6701+ /* host_buffer */ true ,
6702+ /* events */ false ,
6703+ };
6704+ }
6705+
6706+ static ggml_backend_t ggml_backend_vk_device_init (ggml_backend_dev_t dev, const char * params) {
6707+ UNUSED (params);
6708+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6709+ return ggml_backend_vk_init (ctx->device );
6710+ }
6711+
6712+ static bool ggml_backend_vk_device_supports_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
65876713 switch (op->op ) {
65886714 case GGML_OP_UNARY:
65896715 switch (ggml_get_unary_op (op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
67016827 return false ;
67026828 }
67036829
6704- UNUSED (backend);
6705- }
6706-
6707- static bool ggml_backend_vk_offload_op (ggml_backend_t backend, const ggml_tensor * op) {
6708- const int min_batch_size = 32 ;
6709-
6710- return (op->ne [1 ] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
6711- (op->ne [2 ] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
6712-
6713- UNUSED (backend);
6830+ UNUSED (dev);
67146831}
67156832
6716- static bool ggml_backend_vk_supports_buft ( ggml_backend_t backend , ggml_backend_buffer_type_t buft) {
6833+ static bool ggml_backend_vk_device_supports_buft ( ggml_backend_dev_t dev , ggml_backend_buffer_type_t buft) {
67176834 if (buft->iface .get_name != ggml_backend_vk_buffer_type_name) {
67186835 return false ;
67196836 }
67206837
6838+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
67216839 ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context ;
6722- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context ;
6723-
6724- return buft_ctx->device == ctx->device ;
6725- }
6726-
6727- // TODO: enable async and synchronize
6728- static ggml_backend_i ggml_backend_vk_interface = {
6729- /* .get_name = */ ggml_backend_vk_name,
6730- /* .free = */ ggml_backend_vk_free,
6731- /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6732- /* .set_tensor_async = */ NULL , // ggml_backend_vk_set_tensor_async,
6733- /* .get_tensor_async = */ NULL , // ggml_backend_vk_get_tensor_async,
6734- /* .cpy_tensor_async = */ NULL , // ggml_backend_vk_cpy_tensor_async,
6735- /* .synchronize = */ NULL , // ggml_backend_vk_synchronize,
6736- /* .graph_plan_create = */ NULL ,
6737- /* .graph_plan_free = */ NULL ,
6738- /* .graph_plan_update = */ NULL ,
6739- /* .graph_plan_compute = */ NULL ,
6740- /* .graph_compute = */ ggml_backend_vk_graph_compute,
6741- /* .supports_op = */ ggml_backend_vk_supports_op,
6742- /* .supports_buft = */ ggml_backend_vk_supports_buft,
6743- /* .offload_op = */ ggml_backend_vk_offload_op,
6744- /* .event_record = */ NULL ,
6745- /* .event_wait = */ NULL ,
6746- };
67476840
6748- static ggml_guid_t ggml_backend_vk_guid () {
6749- static ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b };
6750- return &guid;
6841+ return buft_ctx->device ->idx == ctx->device ;
67516842}
67526843
6753- ggml_backend_t ggml_backend_vk_init ( size_t dev_num ) {
6754- VK_LOG_DEBUG ( " ggml_backend_vk_init( " << dev_num << " ) " ) ;
6844+ static bool ggml_backend_vk_device_offload_op ( ggml_backend_dev_t dev, const ggml_tensor * op ) {
6845+ const int min_batch_size = 32 ;
67556846
6756- ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6757- ggml_vk_init (ctx, dev_num );
6847+ return (op-> ne [ 1 ] >= min_batch_size && op-> op != GGML_OP_GET_ROWS) ||
6848+ (op-> ne [ 2 ] >= min_batch_size && op-> op == GGML_OP_MUL_MAT_ID );
67586849
6759- ggml_backend_t vk_backend = new ggml_backend {
6760- /* .guid = */ ggml_backend_vk_guid (),
6761- /* .interface = */ ggml_backend_vk_interface,
6762- /* .device = */ nullptr ,
6763- /* .context = */ ctx,
6764- };
6850+ UNUSED (dev);
6851+ }
6852+
6853+ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
6854+ /* .get_name = */ ggml_backend_vk_device_get_name,
6855+ /* .get_description = */ ggml_backend_vk_device_get_description,
6856+ /* .get_memory = */ ggml_backend_vk_device_get_memory,
6857+ /* .get_type = */ ggml_backend_vk_device_get_type,
6858+ /* .get_props = */ ggml_backend_vk_device_get_props,
6859+ /* .init_backend = */ ggml_backend_vk_device_init,
6860+ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
6861+ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
6862+ /* .buffer_from_host_ptr = */ NULL ,
6863+ /* .supports_op = */ ggml_backend_vk_device_supports_op,
6864+ /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
6865+ /* .offload_op = */ ggml_backend_vk_device_offload_op,
6866+ /* .event_new = */ NULL ,
6867+ /* .event_free = */ NULL ,
6868+ /* .event_synchronize = */ NULL ,
6869+ };
67656870
6766- return vk_backend;
6871+ static const char * ggml_backend_vk_reg_get_name (ggml_backend_reg_t reg) {
6872+ UNUSED (reg);
6873+ return GGML_VK_NAME;
67676874}
67686875
6769- bool ggml_backend_is_vk (ggml_backend_t backend) {
6770- return backend != NULL && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6876+ static size_t ggml_backend_vk_reg_get_device_count (ggml_backend_reg_t reg) {
6877+ UNUSED (reg);
6878+ return ggml_backend_vk_get_device_count ();
67716879}
67726880
6773- int ggml_backend_vk_get_device_count () {
6774- return ggml_vk_get_device_count ();
6775- }
6881+ static ggml_backend_dev_t ggml_backend_vk_reg_get_device (ggml_backend_reg_t reg, size_t device) {
6882+ static std::vector<ggml_backend_dev_t > devices;
67766883
6777- void ggml_backend_vk_get_device_description (int device, char * description, size_t description_size) {
6778- ggml_vk_get_device_description (device, description, description_size);
6779- }
6884+ static bool initialized = false ;
67806885
6781- void ggml_backend_vk_get_device_memory (int device, size_t * free, size_t * total) {
6782- GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6886+ {
6887+ static std::mutex mutex;
6888+ std::lock_guard<std::mutex> lock (mutex);
6889+ if (!initialized) {
6890+ for (size_t i = 0 ; i < ggml_backend_vk_get_device_count (); i++) {
6891+ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
6892+ char desc[256 ];
6893+ ggml_backend_vk_get_device_description (i, desc, sizeof (desc));
6894+ ctx->device = i;
6895+ ctx->name = GGML_VK_NAME + std::to_string (i);
6896+ ctx->description = desc;
6897+ devices.push_back (new ggml_backend_device {
6898+ /* .iface = */ ggml_backend_vk_device_i,
6899+ /* .reg = */ reg,
6900+ /* .context = */ ctx,
6901+ });
6902+ }
6903+ initialized = true ;
6904+ }
6905+ }
67836906
6784- vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6907+ GGML_ASSERT (device < devices.size ());
6908+ return devices[device];
6909+ }
67856910
6786- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6911+ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
6912+ /* .get_name = */ ggml_backend_vk_reg_get_name,
6913+ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
6914+ /* .get_device = */ ggml_backend_vk_reg_get_device,
6915+ /* .get_proc_address = */ NULL ,
6916+ };
67876917
6788- for ( const vk::MemoryHeap& heap : memprops. memoryHeaps ) {
6789- if (heap. flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6790- *total = heap. size ;
6791- *free = heap. size ;
6792- break ;
6793- }
6794- }
6918+ ggml_backend_reg_t ggml_backend_vk_reg ( ) {
6919+ static ggml_backend_reg reg = {
6920+ /* .iface = */ ggml_backend_vk_reg_i,
6921+ /* .context = */ nullptr ,
6922+ } ;
6923+
6924+ return ®
67956925}
67966926
67976927// Extension availability
0 commit comments