@@ -1926,7 +1926,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19261926 if (device->fp16 ) {
19271927 device_extensions.push_back (" VK_KHR_shader_float16_int8" );
19281928 }
1929- device->name = device-> properties . deviceName . data ( );
1929+ device->name = GGML_VK_NAME + std::to_string (idx );
19301930
19311931 device_create_info = {
19321932 vk::DeviceCreateFlags (),
@@ -1953,7 +1953,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19531953
19541954 device->buffer_type = {
19551955 /* .iface = */ ggml_backend_vk_buffer_type_interface,
1956- /* .device = */ nullptr ,
1956+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_vk_reg (), idx) ,
19571957 /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name , device },
19581958 };
19591959
@@ -6363,7 +6363,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
63636363 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type ()->iface .get_alloc_size ,
63646364 /* .is_host = */ ggml_backend_cpu_buffer_type ()->iface .is_host ,
63656365 },
6366- /* .device = */ nullptr ,
6366+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_vk_reg (), 0 ) ,
63676367 /* .context = */ nullptr ,
63686368 };
63696369
@@ -6566,9 +6566,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
65666566 UNUSED (backend);
65676567}
65686568
6569- static bool ggml_backend_vk_supports_op (ggml_backend_t backend, const ggml_tensor * op) {
6570- // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
6569+ // TODO: enable async and synchronize
6570+ static ggml_backend_i ggml_backend_vk_interface = {
6571+ /* .get_name = */ ggml_backend_vk_name,
6572+ /* .free = */ ggml_backend_vk_free,
6573+ /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6574+ /* .set_tensor_async = */ NULL , // ggml_backend_vk_set_tensor_async,
6575+ /* .get_tensor_async = */ NULL , // ggml_backend_vk_get_tensor_async,
6576+ /* .cpy_tensor_async = */ NULL , // ggml_backend_vk_cpy_tensor_async,
6577+ /* .synchronize = */ NULL , // ggml_backend_vk_synchronize,
6578+ /* .graph_plan_create = */ NULL ,
6579+ /* .graph_plan_free = */ NULL ,
6580+ /* .graph_plan_update = */ NULL ,
6581+ /* .graph_plan_compute = */ NULL ,
6582+ /* .graph_compute = */ ggml_backend_vk_graph_compute,
6583+ /* .supports_op = */ NULL ,
6584+ /* .supports_buft = */ NULL ,
6585+ /* .offload_op = */ NULL ,
6586+ /* .event_record = */ NULL ,
6587+ /* .event_wait = */ NULL ,
6588+ };
6589+
6590+ static ggml_guid_t ggml_backend_vk_guid () {
6591+ static ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b };
6592+ return &guid;
6593+ }
6594+
6595+ ggml_backend_t ggml_backend_vk_init (size_t dev_num) {
6596+ VK_LOG_DEBUG (" ggml_backend_vk_init(" << dev_num << " )" );
6597+
6598+ ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6599+ ggml_vk_init (ctx, dev_num);
6600+
6601+ ggml_backend_t vk_backend = new ggml_backend {
6602+ /* .guid = */ ggml_backend_vk_guid (),
6603+ /* .interface = */ ggml_backend_vk_interface,
6604+ /* .device = */ ggml_backend_reg_dev_get (ggml_backend_vk_reg (), dev_num),
6605+ /* .context = */ ctx,
6606+ };
6607+
6608+ return vk_backend;
6609+ }
6610+
6611+ bool ggml_backend_is_vk (ggml_backend_t backend) {
6612+ return backend != NULL && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6613+ }
6614+
6615+ int ggml_backend_vk_get_device_count () {
6616+ return ggml_vk_get_device_count ();
6617+ }
6618+
6619+ void ggml_backend_vk_get_device_description (int device, char * description, size_t description_size) {
6620+ GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6621+ int dev_idx = vk_instance.device_indices [device];
6622+ ggml_vk_get_device_description (dev_idx, description, description_size);
6623+ }
6624+
6625+ void ggml_backend_vk_get_device_memory (int device, size_t * free, size_t * total) {
6626+ GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6627+
6628+ vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6629+
6630+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6631+
6632+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps ) {
6633+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6634+ *total = heap.size ;
6635+ *free = heap.size ;
6636+ break ;
6637+ }
6638+ }
6639+ }
6640+
6641+ // ////////////////////////
6642+
// Per-device state stored in a registry device's `context` pointer.
struct ggml_backend_vk_device_context {
    // Backend device index; passed to ggml_backend_vk_init,
    // ggml_backend_vk_buffer_type and ggml_backend_vk_get_device_memory.
    int device;

    // Device name returned by the get_name callback.
    std::string name;

    // Device description returned by the get_description callback.
    std::string description;
};
6648+
6649+ static const char * ggml_backend_vk_device_get_name (ggml_backend_dev_t dev) {
6650+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6651+ return ctx->name .c_str ();
6652+ }
6653+
6654+ static const char * ggml_backend_vk_device_get_description (ggml_backend_dev_t dev) {
6655+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6656+ return ctx->description .c_str ();
6657+ }
6658+
6659+ static void ggml_backend_vk_device_get_memory (ggml_backend_dev_t device, size_t * free, size_t * total) {
6660+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context ;
6661+ ggml_backend_vk_get_device_memory (ctx->device , free, total);
6662+ }
6663+
6664+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type (ggml_backend_dev_t dev) {
6665+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6666+ return ggml_backend_vk_buffer_type (ctx->device );
6667+ }
6668+
6669+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type (ggml_backend_dev_t dev) {
6670+ UNUSED (dev);
6671+ return ggml_backend_vk_host_buffer_type ();
6672+ }
65716673
6674+ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type (ggml_backend_dev_t dev) {
6675+ UNUSED (dev);
6676+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
6677+ }
6678+
6679+ static void ggml_backend_vk_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
6680+ props->name = ggml_backend_vk_device_get_name (dev);
6681+ props->description = ggml_backend_vk_device_get_description (dev);
6682+ props->type = ggml_backend_vk_device_get_type (dev);
6683+ ggml_backend_vk_device_get_memory (dev, &props->memory_free , &props->memory_total );
6684+ props->caps = {
6685+ /* async */ false ,
6686+ /* host_buffer */ true ,
6687+ /* events */ false ,
6688+ };
6689+ }
6690+
6691+ static ggml_backend_t ggml_backend_vk_device_init (ggml_backend_dev_t dev, const char * params) {
6692+ UNUSED (params);
6693+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6694+ return ggml_backend_vk_init (ctx->device );
6695+ }
6696+
6697+ static bool ggml_backend_vk_device_supports_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
65726698 switch (op->op ) {
65736699 case GGML_OP_UNARY:
65746700 switch (ggml_get_unary_op (op)) {
@@ -6686,97 +6812,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
66866812 return false ;
66876813 }
66886814
6689- UNUSED (backend);
6690- }
6691-
6692- static bool ggml_backend_vk_offload_op (ggml_backend_t backend, const ggml_tensor * op) {
6693- const int min_batch_size = 32 ;
6694-
6695- return (op->ne [1 ] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
6696- (op->ne [2 ] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
6697-
6698- UNUSED (backend);
6815+ UNUSED (dev);
66996816}
67006817
6701- static bool ggml_backend_vk_supports_buft ( ggml_backend_t backend , ggml_backend_buffer_type_t buft) {
6818+ static bool ggml_backend_vk_device_supports_buft ( ggml_backend_dev_t dev , ggml_backend_buffer_type_t buft) {
67026819 if (buft->iface .get_name != ggml_backend_vk_buffer_type_name) {
67036820 return false ;
67046821 }
67056822
6823+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
67066824 ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context ;
6707- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context ;
6708-
6709- return buft_ctx->device == ctx->device ;
6710- }
6711-
6712- // TODO: enable async and synchronize
6713- static ggml_backend_i ggml_backend_vk_interface = {
6714- /* .get_name = */ ggml_backend_vk_name,
6715- /* .free = */ ggml_backend_vk_free,
6716- /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6717- /* .set_tensor_async = */ NULL , // ggml_backend_vk_set_tensor_async,
6718- /* .get_tensor_async = */ NULL , // ggml_backend_vk_get_tensor_async,
6719- /* .cpy_tensor_async = */ NULL , // ggml_backend_vk_cpy_tensor_async,
6720- /* .synchronize = */ NULL , // ggml_backend_vk_synchronize,
6721- /* .graph_plan_create = */ NULL ,
6722- /* .graph_plan_free = */ NULL ,
6723- /* .graph_plan_update = */ NULL ,
6724- /* .graph_plan_compute = */ NULL ,
6725- /* .graph_compute = */ ggml_backend_vk_graph_compute,
6726- /* .supports_op = */ ggml_backend_vk_supports_op,
6727- /* .supports_buft = */ ggml_backend_vk_supports_buft,
6728- /* .offload_op = */ ggml_backend_vk_offload_op,
6729- /* .event_record = */ NULL ,
6730- /* .event_wait = */ NULL ,
6731- };
67326825
6733- static ggml_guid_t ggml_backend_vk_guid () {
6734- static ggml_guid guid = { 0xb8 , 0xf7 , 0x4f , 0x86 , 0x40 , 0x3c , 0xe1 , 0x02 , 0x91 , 0xc8 , 0xdd , 0xe9 , 0x02 , 0x3f , 0xc0 , 0x2b };
6735- return &guid;
6826+ return buft_ctx->device ->idx == ctx->device ;
67366827}
67376828
6738- ggml_backend_t ggml_backend_vk_init ( size_t dev_num ) {
6739- VK_LOG_DEBUG ( " ggml_backend_vk_init( " << dev_num << " ) " ) ;
6829+ static bool ggml_backend_vk_device_offload_op ( ggml_backend_dev_t dev, const ggml_tensor * op ) {
6830+ const int min_batch_size = 32 ;
67406831
6741- ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6742- ggml_vk_init (ctx, dev_num );
6832+ return (op-> ne [ 1 ] >= min_batch_size && op-> op != GGML_OP_GET_ROWS) ||
6833+ (op-> ne [ 2 ] >= min_batch_size && op-> op == GGML_OP_MUL_MAT_ID );
67436834
6744- ggml_backend_t vk_backend = new ggml_backend {
6745- /* .guid = */ ggml_backend_vk_guid (),
6746- /* .interface = */ ggml_backend_vk_interface,
6747- /* .device = */ nullptr ,
6748- /* .context = */ ctx,
6749- };
6835+ UNUSED (dev);
6836+ }
6837+
6838+ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
6839+ /* .get_name = */ ggml_backend_vk_device_get_name,
6840+ /* .get_description = */ ggml_backend_vk_device_get_description,
6841+ /* .get_memory = */ ggml_backend_vk_device_get_memory,
6842+ /* .get_type = */ ggml_backend_vk_device_get_type,
6843+ /* .get_props = */ ggml_backend_vk_device_get_props,
6844+ /* .init_backend = */ ggml_backend_vk_device_init,
6845+ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
6846+ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
6847+ /* .buffer_from_host_ptr = */ NULL ,
6848+ /* .supports_op = */ ggml_backend_vk_device_supports_op,
6849+ /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
6850+ /* .offload_op = */ ggml_backend_vk_device_offload_op,
6851+ /* .event_new = */ NULL ,
6852+ /* .event_free = */ NULL ,
6853+ /* .event_synchronize = */ NULL ,
6854+ };
67506855
6751- return vk_backend;
6856+ static const char * ggml_backend_vk_reg_get_name (ggml_backend_reg_t reg) {
6857+ UNUSED (reg);
6858+ return GGML_VK_NAME;
67526859}
67536860
6754- bool ggml_backend_is_vk (ggml_backend_t backend) {
6755- return backend != NULL && ggml_guid_matches (backend->guid , ggml_backend_vk_guid ());
6861+ static size_t ggml_backend_vk_reg_get_device_count (ggml_backend_reg_t reg) {
6862+ UNUSED (reg);
6863+ return ggml_backend_vk_get_device_count ();
67566864}
67576865
6758- int ggml_backend_vk_get_device_count () {
6759- return ggml_vk_get_device_count ();
6760- }
6866+ static ggml_backend_dev_t ggml_backend_vk_reg_get_device (ggml_backend_reg_t reg, size_t device) {
6867+ static std::vector<ggml_backend_dev_t > devices;
67616868
6762- void ggml_backend_vk_get_device_description (int device, char * description, size_t description_size) {
6763- ggml_vk_get_device_description (device, description, description_size);
6764- }
6869+ static bool initialized = false ;
67656870
6766- void ggml_backend_vk_get_device_memory (int device, size_t * free, size_t * total) {
6767- GGML_ASSERT (device < (int ) vk_instance.device_indices .size ());
6871+ {
6872+ static std::mutex mutex;
6873+ std::lock_guard<std::mutex> lock (mutex);
6874+ if (!initialized) {
6875+ for (size_t i = 0 ; i < ggml_backend_vk_get_device_count (); i++) {
6876+ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
6877+ char desc[256 ];
6878+ ggml_backend_vk_get_device_description (i, desc, sizeof (desc));
6879+ ctx->device = i;
6880+ ctx->name = GGML_VK_NAME + std::to_string (i);
6881+ ctx->description = desc;
6882+ devices.push_back (new ggml_backend_device {
6883+ /* .iface = */ ggml_backend_vk_device_i,
6884+ /* .reg = */ reg,
6885+ /* .context = */ ctx,
6886+ });
6887+ }
6888+ initialized = true ;
6889+ }
6890+ }
67686891
6769- vk::PhysicalDevice vkdev = vk_instance.instance .enumeratePhysicalDevices ()[vk_instance.device_indices [device]];
6892+ GGML_ASSERT (device < devices.size ());
6893+ return devices[device];
6894+ }
67706895
6771- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties ();
6896+ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
6897+ /* .get_name = */ ggml_backend_vk_reg_get_name,
6898+ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
6899+ /* .get_device = */ ggml_backend_vk_reg_get_device,
6900+ /* .get_proc_address = */ NULL ,
6901+ };
67726902
6773- for ( const vk::MemoryHeap& heap : memprops. memoryHeaps ) {
6774- if (heap. flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6775- *total = heap. size ;
6776- *free = heap. size ;
6777- break ;
6778- }
6779- }
6903+ ggml_backend_reg_t ggml_backend_vk_reg ( ) {
6904+ static ggml_backend_reg reg = {
6905+ /* .iface = */ ggml_backend_vk_reg_i,
6906+ /* .context = */ nullptr ,
6907+ } ;
6908+
6909+ return ®
67806910}
67816911
67826912// Extension availability
0 commit comments