@@ -5918,10 +5918,13 @@ static void ggml_backend_metal_split_buffer_context_free(struct ggml_backend_met
 static void get_row_split(int64_t * row_low, int64_t * row_high, const struct ggml_tensor * tensor, const float tensor_split[1], int id) {
     GGML_LOG_DEBUG("%s: tensor '%s', id=%d, ne[1]=%lld\n", __func__, tensor->name, id, tensor->ne[1]);

+    const int64_t nrows = ggml_nrows(tensor);
+
     // For Metal, we only have one device, so all rows go to device 0
     if (id == 0) {
-        *row_low = 0;
-        *row_high = tensor->ne[1];
+        // Use the tensor_split value to determine how much of the tensor goes to this device
+        *row_low  = id == 0 ? 0 : (int64_t)(nrows * tensor_split[id]);
+        *row_high = id == 0 ? nrows : (int64_t)(nrows * tensor_split[id]);
         GGML_LOG_DEBUG("%s: assigning rows [%lld, %lld] to device %d\n", __func__, *row_low, *row_high, id);
     } else {
         *row_low = 0;
@@ -5930,7 +5933,6 @@ static void get_row_split(int64_t * row_low, int64_t * row_high, const struct gg
     }

     GGML_LOG_DEBUG("%s: tensor_split[0] = %f\n", __func__, (double)tensor_split[0]);
-    GGML_UNUSED(tensor_split);
 }

 // Buffer free function
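
For context, the hunks above keep the single-device behaviour of get_row_split: with id == 0 the full row range is assigned and tensor_split does not actually shrink it. A minimal standalone sketch of that arithmetic (plain C, hypothetical row count and split values, no ggml types):

#include <stdint.h>
#include <stdio.h>

// Hypothetical stand-in for get_row_split() with a single Metal device:
// device 0 owns the whole [0, nrows) range, any other id gets an empty range.
static void row_split_sketch(int64_t nrows, const float tensor_split[1], int id,
                             int64_t * row_low, int64_t * row_high) {
    if (id == 0) {
        *row_low  = 0;
        *row_high = nrows;      // full range; tensor_split[0] does not reduce it here
    } else {
        *row_low  = 0;
        *row_high = 0;          // empty range: there is no second Metal device
    }
    (void) tensor_split;
}

int main(void) {
    const float split[1] = { 1.0f };
    int64_t lo = 0, hi = 0;
    row_split_sketch(4096, split, 0, &lo, &hi);
    printf("device 0 rows: [%lld, %lld)\n", (long long) lo, (long long) hi);
    return 0;
}
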
@@ -6053,7 +6055,8 @@ static void ggml_backend_metal_split_buffer_set_tensor(
     GGML_ASSERT(size == ggml_nbytes(tensor));
     GGML_ASSERT(ggml_is_contiguous(tensor));

-    struct ggml_backend_metal_split_buffer_context *ctx = (struct ggml_backend_metal_split_buffer_type_context *) buffer->buft->context;
+    struct ggml_backend_metal_split_buffer_context *ctx = (struct ggml_backend_metal_split_buffer_context *) buffer->context;
+    struct ggml_backend_metal_split_buffer_type_context *buft_ctx = (struct ggml_backend_metal_split_buffer_type_context *) buffer->buft->context;
     const int64_t ne0 = tensor->ne[0];
     const size_t nb1 = tensor->nb[1];
     struct ggml_tensor_extra_metal *extra = (struct ggml_tensor_extra_metal *) tensor->extra;
@@ -6063,7 +6066,7 @@ static void ggml_backend_metal_split_buffer_set_tensor(
     for (int id = 0; id < device_count; ++id) {
         const float id_ = 1.0f;
         int64_t row_low = 0, row_high = 0;
-        get_row_split(&row_low, &row_high, tensor, &id_, id);
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
         int64_t nrows = row_high - row_low;
         if (nrows <= 0) {
             continue;
@@ -6081,7 +6084,7 @@ static void ggml_backend_metal_split_buffer_set_tensor(
         // Copy the slice into the Metal buffer (contents pointer to GPU memory)
         memcpy([extra->data_device[id] contents], buf_host, original_size);
         // On macOS, inform Metal that buffer range was modified so GPU sees new data
-        // [extra->data_device[id] didModifyRange:NSMakeRange(0, original_size)];
+        [extra->data_device[id] didModifyRange:NSMakeRange(0, original_size)];
     }
 }

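Enabling the didModifyRange: call above matters whenever the underlying MTLBuffer uses managed storage on macOS: writes made through contents are only guaranteed to become visible to the GPU after the modified byte range is flagged. A small self-contained Objective-C sketch of that pattern (hypothetical buffer size and payload, outside the ggml code path):

#import <Metal/Metal.h>
#include <string.h>

int main(void) {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();

    // Managed storage keeps separate CPU and GPU copies that must be synchronized explicitly.
    id<MTLBuffer> buf = [device newBufferWithLength:4096
                                            options:MTLResourceStorageModeManaged];

    // Write through the CPU-visible pointer...
    const float data[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    memcpy([buf contents], data, sizeof(data));

    // ...then tell Metal which byte range changed so the GPU copy gets updated.
    [buf didModifyRange:NSMakeRange(0, sizeof(data))];

    return 0;
}
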
@@ -6093,6 +6096,7 @@ static void ggml_backend_metal_split_buffer_get_tensor(ggml_backend_buffer_t buf
     GGML_ASSERT(size == ggml_nbytes(tensor));
     GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");

+    struct ggml_backend_metal_split_buffer_context * ctx = (struct ggml_backend_metal_split_buffer_context *)buffer->context;
     struct ggml_backend_metal_split_buffer_type_context * buft_ctx = (struct ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;

     const int64_t ne0 = tensor->ne[0];
@@ -6160,6 +6164,10 @@ static ggml_backend_buffer_t ggml_backend_split_buffer_type_alloc_buffer(ggml_ba
     // However, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
     // as returned by get_alloc_size. This limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
     struct ggml_backend_metal_split_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_split_buffer_context));
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate split buffer context\n", __func__);
+        return NULL;
+    }

     return ggml_backend_buffer_init(buft, ggml_backend_metal_split_buffer_interface, ctx, size);
 }
@@ -6227,10 +6235,17 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_split_buffer_type(int m
     }

     ctx->main_device = main_device;
-    ctx->tensor_split[0] = 1.0f; // All tensors go to the single Metal device
-    ctx->name = "Metal_Split";

-    GGML_LOG_DEBUG("%s: tensor_split[0] = %f\n", __func__, (double)ctx->tensor_split[0]);
+    // Properly handle tensor split values
+    if (tensor_split != NULL) {
+        ctx->tensor_split[0] = tensor_split[0];
+        GGML_LOG_DEBUG("%s: tensor_split[0] = %f (from input)\n", __func__, (double)ctx->tensor_split[0]);
+    } else {
+        ctx->tensor_split[0] = 1.0f; // All tensors go to the single Metal device
+        GGML_LOG_DEBUG("%s: tensor_split[0] = %f (default)\n", __func__, (double)ctx->tensor_split[0]);
+    }
+
+    ctx->name = "Metal_Split";

     // Allocate a new buffer type structure each time
     struct ggml_backend_buffer_type * buft = calloc(1, sizeof(struct ggml_backend_buffer_type));
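
With the NULL handling above, callers can either supply an explicit split array or pass NULL and get the single-device default of 1.0f. A hedged usage sketch, assuming the full parameter list is (int main_device, const float * tensor_split) as the truncated declaration and the function body suggest, and assuming the declaration is visible through the backend's public header:

#include "ggml-backend.h"   // for ggml_backend_buffer_type_t; the split-buffer-type declaration is assumed visible here

// Hypothetical caller; the parameter list is inferred from the function body above, not from the real header.
static ggml_backend_buffer_type_t make_split_buft(const float * maybe_split) {
    // Passing NULL falls back to tensor_split[0] = 1.0f inside ggml_backend_split_buffer_type().
    return ggml_backend_split_buffer_type(/*main_device=*/0, maybe_split);
}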