@@ -5878,6 +5878,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
 // backend interface
 
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_split_buffer_type(int main_device, const float * tensor_split) {
+    GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);
+
+    // For Metal split buffer type, we return the regular Metal buffer type
+    // since Metal currently only supports one device
+    ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
+    GGML_LOG_INFO("%s: returning Metal buffer type\n", __func__);
+    return buft;
+
+    GGML_UNUSED(main_device);
+    GGML_UNUSED(tensor_split);
+}
+
+
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
 
@@ -6593,7 +6607,7 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
 
 static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
     if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
-        return (void *)ggml_backend_metal_split_buffer_type;
+        return (void *)ggml_backend_split_buffer_type;
     }
     if (strcmp(name, "ggml_backend_get_features") == 0) {
         return (void *)ggml_backend_metal_get_features;
@@ -6631,7 +6645,7 @@ static void ggml_metal_cleanup(void) {
 };
 
 // Buffer type context
-struct ggml_backend_metal_split_buffer_type_context {
+struct ggml_backend_split_buffer_type_context {
     int main_device;
     std::array<float, 1> tensor_split; // Metal only supports one device, but keeping array for API consistency
     std::string name;
@@ -6660,6 +6674,7 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl
66606674
66616675// Tensor split calculation
66626676static void get_row_split (int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float , 1 > & tensor_split, int id ) {
6677+ GGML_LOG_INFO (" Returning Row Splits.\n " );
66636678 // For Metal, we only have one device, so all rows go to device 0
66646679 if (id == 0 ) {
66656680 *row_low = 0 ;
@@ -6692,7 +6707,7 @@ static enum ggml_status ggml_backend_metal_split_buffer_init_tensor(ggml_backend
     GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_metal_split_buffer_context * ctx = (ggml_backend_metal_split_buffer_context *)buffer->context;
-    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+    ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
 
     const int64_t ne0 = tensor->ne[0];
 
@@ -6739,7 +6754,7 @@ static void ggml_backend_metal_split_buffer_set_tensor(ggml_backend_buffer_t buf
     GGML_ASSERT(size == ggml_nbytes(tensor));
     GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
-    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+    ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
 
     const int64_t ne0 = tensor->ne[0];
     const size_t nb1 = tensor->nb[1];
@@ -6777,7 +6792,7 @@ static void ggml_backend_metal_split_buffer_get_tensor(ggml_backend_buffer_t buf
     GGML_ASSERT(size == ggml_nbytes(tensor));
     GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
-    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+    ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
 
     const int64_t ne0 = tensor->ne[0];
     const size_t nb1 = tensor->nb[1];
@@ -6829,16 +6844,16 @@ static void ggml_backend_metal_split_buffer_clear(ggml_backend_buffer_t buffer,
 };
 
 // Buffer type interface functions
-static const char * ggml_backend_metal_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+static const char * ggml_backend_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_split_buffer_type_context * ctx = (ggml_backend_split_buffer_type_context *)buft->context;
     return ctx->name.c_str();
 }
 
 static bool ggml_backend_buft_is_metal_split(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_split_buffer_type_get_name;
+    return buft->iface.get_name == ggml_backend_split_buffer_type_get_name;
 }
 
-static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+static ggml_backend_buffer_t ggml_backend_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     // Since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
     // Instead, we allocate them for each tensor separately in init_tensor
     // However, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
@@ -6848,14 +6863,14 @@ static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(g
 
     return ggml_backend_buffer_init(buft, ggml_backend_metal_split_buffer_interface, ctx, size);
 }
 
-static size_t ggml_backend_metal_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+static size_t ggml_backend_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 128;
 
     GGML_UNUSED(buft);
 }
 
-static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+static size_t ggml_backend_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_split_buffer_type_context * ctx = (ggml_backend_split_buffer_type_context *)buft->context;
     GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     size_t total_size = 0;
@@ -6882,66 +6897,24 @@ static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_b
     return total_size;
 }
 
-static bool ggml_backend_metal_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+static bool ggml_backend_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return false;
 
     GGML_UNUSED(buft);
 }
 
 // Buffer type interface
-static const ggml_backend_buffer_type_i ggml_backend_metal_split_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_metal_split_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_metal_split_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_metal_split_buffer_type_get_alignment,
+static const ggml_backend_buffer_type_i ggml_backend_split_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_split_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_split_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_split_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_metal_split_buffer_type_get_alloc_size,
-    /* .is_host          = */ ggml_backend_metal_split_buffer_type_is_host,
+    /* .get_alloc_size   = */ ggml_backend_split_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_split_buffer_type_is_host,
 };
 
 #endif // __cplusplus
 
-// Main function to create Metal split buffer type
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_split_buffer_type(int main_device, const float * tensor_split) {
-    GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);
-#ifdef __cplusplus
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    static std::map<std::pair<int, std::array<float, 1>>, struct ggml_backend_buffer_type> buft_map;
-
-    std::array<float, 1> tensor_split_arr = {};
-
-    // For Metal, we only support one device, so we simplify the tensor split logic
-    tensor_split_arr[0] = 1.0f; // All tensors go to the single Metal device
-
-    auto it = buft_map.find({main_device, tensor_split_arr});
-    if (it != buft_map.end()) {
-        GGML_LOG_INFO("%s: returning existing buffer type\n", __func__);
-        return &it->second;
-    }
-
-    auto * ctx = new ggml_backend_metal_split_buffer_type_context{
-        main_device,
-        tensor_split_arr,
-        std::string("Metal_Split"),
-    };
-
-    struct ggml_backend_buffer_type buft {
-        /* .iface   = */ ggml_backend_metal_split_buffer_type_interface,
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), main_device),
-        /* .context = */ ctx,
-    };
-
-    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
-    GGML_LOG_INFO("%s: created new Metal split buffer type\n", __func__);
-    return &result.first->second;
-#else
-    // For C builds, return the regular Metal buffer type
-    GGML_LOG_INFO("%s: C build, returning regular Metal buffer type\n", __func__);
-    return ggml_backend_metal_buffer_type();
-#endif
-}
-
 // TODO: make thread-safe
 ggml_backend_reg_t ggml_backend_metal_reg(void) {
     ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);