
#import <Metal/Metal.h>

+#ifdef __cplusplus
+#include <array>
+#include <map>
+#include <mutex>
+#include <string>
+#include <vector>
+#endif
+
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -1698,6 +1705,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
    id rset;
};

+// Helper function to calculate tensor size for split buffers
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    // Calculate the size based on the number of rows in the split
+    return nrows_split * ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
// rset init
static bool ggml_backend_metal_buffer_rset_init(
        struct ggml_backend_metal_buffer_context * ctx,
@@ -6579,6 +6592,9 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
}

static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_metal_split_buffer_type;
+    }
    if (strcmp(name, "ggml_backend_get_features") == 0) {
        return (void *)ggml_backend_metal_get_features;
    }
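For reference, this is the entry point that generic client code would use to discover the split buffer type at runtime. Below is a minimal sketch of that lookup, assuming ggml's public registry helper ggml_backend_reg_get_proc_address and the allocator helper ggml_backend_alloc_ctx_tensors_from_buft; the snippet is illustrative only and not part of the patch:

    // minimal sketch (not part of the patch): querying the new entry point
    typedef ggml_backend_buffer_type_t (*split_buffer_type_fn)(int main_device, const float * tensor_split);

    ggml_backend_reg_t reg = ggml_backend_metal_reg();
    split_buffer_type_fn fn = (split_buffer_type_fn)
        ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
    if (fn != NULL) {
        const float tensor_split[1] = { 1.0f }; // single Metal device
        ggml_backend_buffer_type_t buft = fn(/*main_device =*/ 0, tensor_split);
        // buft can then be passed to e.g. ggml_backend_alloc_ctx_tensors_from_buft()
    }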
@@ -6599,6 +6615,327 @@ static void ggml_metal_cleanup(void) {
    ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
}

+//
+// Metal split buffer implementation
+//
+
+#ifdef __cplusplus
+
+#define MATRIX_ROW_PADDING 512 // as defined in the CUDA implementation
+
+// Metal equivalent of ggml_tensor_extra_gpu
+struct ggml_tensor_extra_metal {
+    // Metal buffers for each device (Metal only supports one device in the current implementation,
+    // but we keep the array structure for consistency with CUDA)
+    id<MTLBuffer> data_device[1]; // Metal only supports one device currently
+};
+
+// Buffer type context
+struct ggml_backend_metal_split_buffer_type_context {
+    int main_device;
+    std::array<float, 1> tensor_split; // Metal only supports one device, but the array is kept for API consistency
+    std::string name;
+};
+
+// Buffer context
+struct ggml_backend_metal_split_buffer_context {
+    ~ggml_backend_metal_split_buffer_context() {
+        for (ggml_tensor_extra_metal * extra : tensor_extras) {
+            // Clean up Metal buffers
+            if (extra->data_device[0] != nullptr) {
+                [extra->data_device[0] release];
+            }
+            delete extra;
+        }
+    }
+
+    std::vector<ggml_tensor_extra_metal *> tensor_extras;
+};
+
+// Tensor split calculation
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, 1> & tensor_split, int id) {
+    // For Metal, we only have one device, so all rows go to device 0
+    if (id == 0) {
+        *row_low  = 0;
+        *row_high = tensor->ne[1];
+    } else {
+        *row_low  = 0;
+        *row_high = 0;
+    }
+
+    GGML_UNUSED(tensor_split);
+}
+
+// Buffer free function
+static void ggml_backend_metal_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_metal_split_buffer_context * ctx = (ggml_backend_metal_split_buffer_context *)buffer->context;
+    delete ctx;
+}
+
+// Buffer get base function
+static void * ggml_backend_metal_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // The pointers are stored in the tensor extras, this is just a dummy address
+    return (void *)0x1000;
+
+    GGML_UNUSED(buffer);
+}
+
+// Buffer init tensor function
+static enum ggml_status ggml_backend_metal_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_context      * ctx      = (ggml_backend_metal_split_buffer_context      *)buffer->context;
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    ggml_tensor_extra_metal * extra = new ggml_tensor_extra_metal{};
+    ctx->tensor_extras.push_back(extra);
+
+    // For Metal, we only have one device
+    int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        tensor->extra = extra;
+        return GGML_STATUS_SUCCESS;
+    }
+
+    size_t size = ggml_nbytes_split(tensor, nrows_split);
+    const size_t original_size = size;
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    // Get the Metal device context
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buffer->buft->device->context;
+    id<MTLDevice> device = ctx_dev->mtl_device;
+
+    // Allocate the Metal buffer (shared storage, so its contents are CPU-accessible for the memset/memcpy below)
+    extra->data_device[id] = [device newBufferWithLength:size options:MTLResourceStorageModeShared];
+
+    // Initialize buffer with zeros
+    memset([extra->data_device[id] contents], 0, size);
+
+    tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
+}
+
+// Buffer set tensor function
+static void ggml_backend_metal_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // Split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t  nb1 = tensor->nb[1];
+    ggml_tensor_extra_metal * extra = (ggml_tensor_extra_metal *)tensor->extra;
+
+    // For Metal, we only have one device
+    int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return;
+    }
+
+    const size_t offset_split = row_low * nb1;
+    size_t alloc_size = ggml_nbytes_split(tensor, nrows_split);
+    const size_t original_size = alloc_size;
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        alloc_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    const char * buf_host = (const char *)data + offset_split;
+
+    // Copy data to the Metal buffer
+    memcpy([extra->data_device[id] contents], buf_host, original_size);
+}
+
+// Buffer get tensor function
+static void ggml_backend_metal_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // Split tensors must always be retrieved in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t  nb1 = tensor->nb[1];
+    ggml_tensor_extra_metal * extra = (ggml_tensor_extra_metal *)tensor->extra;
+
+    // For Metal, we only have one device
+    int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return;
+    }
+
+    const size_t offset_split = row_low * nb1;
+    size_t alloc_size = ggml_nbytes_split(tensor, nrows_split);
+    const size_t original_size = alloc_size;
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        alloc_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    char * buf_host = (char *)data + offset_split;
+
+    // Copy data from the Metal buffer
+    memcpy(buf_host, [extra->data_device[id] contents], original_size);
+}
+
+// Buffer clear function
+static void ggml_backend_metal_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(value);
+    // Not implemented for split buffers
+}
+
+// Buffer interface
+static const ggml_backend_buffer_i ggml_backend_metal_split_buffer_interface = {
+    /* .free_buffer   = */ ggml_backend_metal_split_buffer_free_buffer,
+    /* .get_base      = */ ggml_backend_metal_split_buffer_get_base,
+    /* .init_tensor   = */ ggml_backend_metal_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
+    /* .set_tensor    = */ ggml_backend_metal_split_buffer_set_tensor,
+    /* .get_tensor    = */ ggml_backend_metal_split_buffer_get_tensor,
+    /* .cpy_tensor    = */ NULL,
+    /* .clear         = */ ggml_backend_metal_split_buffer_clear,
+    /* .reset         = */ NULL,
+};
+
+// Buffer type interface functions
+static const char * ggml_backend_metal_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+    return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_metal_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_split_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    // Since we don't know the exact split after rounding, we cannot allocate the device buffers at this point.
+    // Instead, we allocate them for each tensor separately in init_tensor.
+    // However, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+    // as returned by get_alloc_size. This limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+    ggml_backend_metal_split_buffer_context * ctx = new ggml_backend_metal_split_buffer_context();
+
+    return ggml_backend_buffer_init(buft, ggml_backend_metal_split_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_metal_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    size_t total_size = 0;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    // For Metal, we only have one device
+    int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+    int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return total_size;
+    }
+
+    total_size += ggml_nbytes_split(tensor, nrows_split);
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    return total_size;
+}
+
+static bool ggml_backend_metal_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+// Buffer type interface
+static const ggml_backend_buffer_type_i ggml_backend_metal_split_buffer_type_interface = {
+    /* .get_name       = */ ggml_backend_metal_split_buffer_type_get_name,
+    /* .alloc_buffer   = */ ggml_backend_metal_split_buffer_type_alloc_buffer,
+    /* .get_alignment  = */ ggml_backend_metal_split_buffer_type_get_alignment,
+    /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size = */ ggml_backend_metal_split_buffer_type_get_alloc_size,
+    /* .is_host        = */ ggml_backend_metal_split_buffer_type_is_host,
+};
+
+#endif // __cplusplus
+
+// Main function to create the Metal split buffer type
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_split_buffer_type(int main_device, const float * tensor_split) {
+    GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);
+#ifdef __cplusplus
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    static std::map<std::pair<int, std::array<float, 1>>, struct ggml_backend_buffer_type> buft_map;
+
+    std::array<float, 1> tensor_split_arr = {};
+
+    // For Metal, we only support one device, so the tensor split logic is simplified
+    tensor_split_arr[0] = 1.0f; // all tensors go to the single Metal device
+
+    auto it = buft_map.find({main_device, tensor_split_arr});
+    if (it != buft_map.end()) {
+        GGML_LOG_INFO("%s: returning existing buffer type\n", __func__);
+        return &it->second;
+    }
+
+    auto * ctx = new ggml_backend_metal_split_buffer_type_context{
+        main_device,
+        tensor_split_arr,
+        std::string("Metal_Split"),
+    };
+
+    struct ggml_backend_buffer_type buft {
+        /* .iface   = */ ggml_backend_metal_split_buffer_type_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), main_device),
+        /* .context = */ ctx,
+    };
+
+    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
+    GGML_LOG_INFO("%s: created new Metal split buffer type\n", __func__);
+    return &result.first->second;
+#else
+    // For C builds, return the regular Metal buffer type
+    GGML_LOG_INFO("%s: C build, returning regular Metal buffer type\n", __func__);
+    return ggml_backend_metal_buffer_type();
+#endif
+}
+
// TODO: make thread-safe
ggml_backend_reg_t ggml_backend_metal_reg(void) {
    ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
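As a quick sanity check on the padding rule shared by init_tensor, set/get_tensor, and get_alloc_size above, here is a worked instance. It is illustrative only (not part of the patch) and assumes a contiguous F16 tensor with ne0 = 4000 and all 32 rows assigned to the single Metal device:

    // illustrative arithmetic, not part of the patch
    const int64_t ne0 = 4000, nrows_split = 32;
    size_t size = nrows_split * ggml_row_size(GGML_TYPE_F16, ne0);  // 32 * 8000 = 256000 bytes
    if (ne0 % MATRIX_ROW_PADDING != 0) {
        // 4000 % 512 = 416, so the last row is padded by 512 - 416 = 96 elements
        size += ggml_row_size(GGML_TYPE_F16, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);  // +192 bytes
    }
    // size == 256192, which is what get_alloc_size reports and init_tensor allocates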