
Commit b53f098

Implement ggml_backend_metal_split_buffer_type so the Metal backend supports split-mode row
1 parent ae355f6 commit b53f098

File tree (2 files changed, +345 -0):

  ggml/include/ggml-metal.h
  ggml/src/ggml-metal/ggml-metal.m

ggml/include/ggml-metal.h

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backe
 
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
 
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_split_buffer_type(int main_device, const float * tensor_split);
+
 // helper to check if the device supports a specific family
 // ideally, the user code should be doing these checks
 // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
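
The new declaration mirrors the CUDA backend's split buffer type entry point. A minimal usage sketch (not part of this commit; with a single Metal device the tensor_split fractions are ignored, so passing NULL is acceptable):

    ggml_backend_buffer_type_t buft = ggml_backend_metal_split_buffer_type(0, /*tensor_split=*/NULL);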

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 343 additions & 0 deletions
@@ -8,6 +8,14 @@
 
 #import <Metal/Metal.h>
 
+#ifdef __cplusplus
+#include <array>
+#include <map>
+#include <mutex>
+#include <string> // std::string is used by the split buffer type context below
+#include <vector>
+#endif
+
 #undef MIN
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -1698,6 +1705,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     id rset;
 };
 
+// Helper function to calculate tensor size for split buffers
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    // Calculate the size based on the number of rows in the split
+    return nrows_split * ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
 // rset init
 static bool ggml_backend_metal_buffer_rset_init(
         struct ggml_backend_metal_buffer_context * ctx,
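
For intuition: ggml_row_size returns the byte size of one row of ne[0] elements for the given type, so this helper scales linearly in the rows owned by the split. A small worked example (values from the public ggml API; the tensor t is illustrative):

    // one F32 row of 4096 elements is 4096*4 = 16384 bytes,
    // so a 100-row split occupies 100*16384 = 1638400 bytes
    size_t bytes = ggml_nbytes_split(t, 100); // == 100 * ggml_row_size(GGML_TYPE_F32, 4096)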
@@ -6579,6 +6592,9 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
 }
 
 static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_metal_split_buffer_type;
+    }
     if (strcmp(name, "ggml_backend_get_features") == 0) {
         return (void *)ggml_backend_metal_get_features;
     }
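
Exposing the symbol under the generic "ggml_backend_split_buffer_type" name lets backend-agnostic code discover it at runtime. A minimal sketch using the public registry API (the function pointer typedef here is illustrative, not from this commit):

    typedef ggml_backend_buffer_type_t (*split_buft_fn_t)(int main_device, const float * tensor_split);

    ggml_backend_reg_t reg = ggml_backend_metal_reg();
    split_buft_fn_t fn = (split_buft_fn_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
    if (fn != NULL) {
        ggml_backend_buffer_type_t buft = fn(0, NULL); // split-mode row path
    }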
@@ -6599,6 +6615,333 @@ static void ggml_metal_cleanup(void) {
     ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
 }
 
+//
+// Metal split buffer implementation
+//
+
+#ifdef __cplusplus
+
+#define MATRIX_ROW_PADDING 512 // As defined in CUDA implementation
+
+// Metal equivalent of ggml_tensor_extra_gpu
+struct ggml_tensor_extra_metal {
+    // Metal buffers for each device; Metal only supports one device in the current
+    // implementation, but we keep the array structure for consistency with CUDA
+    id<MTLBuffer> data_device[1];
+};
+
+// Buffer type context
+struct ggml_backend_metal_split_buffer_type_context {
+    int main_device;
+    std::array<float, 1> tensor_split; // single device, but keeping the array for API consistency
+    std::string name;
+};
+
+// Buffer context
+struct ggml_backend_metal_split_buffer_context {
+    ~ggml_backend_metal_split_buffer_context() {
+        for (ggml_tensor_extra_metal * extra : tensor_extras) {
+            // Clean up Metal buffers
+            if (extra->data_device[0] != nullptr) {
+                [extra->data_device[0] release];
+            }
+            delete extra;
+        }
+    }
+
+    std::vector<ggml_tensor_extra_metal *> tensor_extras;
+};
+
+// note: ggml_nbytes_split is defined earlier in this file and is reused here
+
+// Tensor split calculation
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, 1> & tensor_split, int id) {
+    // For Metal, we only have one device, so all rows go to device 0
+    if (id == 0) {
+        *row_low  = 0;
+        *row_high = tensor->ne[1];
+    } else {
+        *row_low  = 0;
+        *row_high = 0;
+    }
+
+    GGML_UNUSED(tensor_split);
+}
+
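+// In a multi-device backend (the CUDA code this mirrors), the slice for device id
+// would instead come from the tensor_split fractions, roughly:
+//
+//     *row_low  = id == 0     ? 0             : (int64_t) (tensor->ne[1]*tensor_split[id - 1]);
+//     *row_high = id == n - 1 ? tensor->ne[1] : (int64_t) (tensor->ne[1]*tensor_split[id]);
+//
+// with the boundaries rounded for alignment; the single-device version above is the
+// degenerate case in which device 0 owns the whole [0, ne[1]) range.
+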
+// Buffer free function
+static void ggml_backend_metal_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_metal_split_buffer_context * ctx = (ggml_backend_metal_split_buffer_context *)buffer->context;
+    delete ctx;
+}
+
+// Buffer get base function
+static void * ggml_backend_metal_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // The pointers are stored in the tensor extras, this is just a dummy address
+    return (void *)0x1000;
+
+    GGML_UNUSED(buffer);
+}
+
+// Buffer init tensor function
+static enum ggml_status ggml_backend_metal_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_context * ctx = (ggml_backend_metal_split_buffer_context *)buffer->context;
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    ggml_tensor_extra_metal * extra = new ggml_tensor_extra_metal{};
+    ctx->tensor_extras.push_back(extra);
+
+    // For Metal, we only have one device
+    const int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    const int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        tensor->extra = extra;
+        return GGML_STATUS_SUCCESS;
+    }
+
+    size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    // Get Metal device context
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buffer->buft->device->context;
+    id<MTLDevice> device = ctx_dev->mtl_device;
+
+    // Allocate the Metal buffer with shared storage: [buffer contents] returns NULL
+    // for private buffers, and init/set/get below access the memory through contents
+    extra->data_device[id] = [device newBufferWithLength:size options:MTLResourceStorageModeShared];
+    if (extra->data_device[id] == nil) {
+        GGML_LOG_ERROR("%s: failed to allocate Metal buffer of size %zu\n", __func__, size);
+        return GGML_STATUS_ALLOC_FAILED;
+    }
+
+    // Initialize buffer with zeros (this also clears the row padding)
+    memset([extra->data_device[id] contents], 0, size);
+
+    tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
+}
+
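+// Worked example of the padding above: for ne0 = 1000, 1000 % 512 = 488, so space
+// for another 512 - 488 = 24 elements is reserved and the row is laid out as if it
+// were 1024 elements wide; kernels can then read past the logical end of the last
+// row without bounds checks, matching the CUDA backend's MATRIX_ROW_PADDING.
+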
+// Buffer set tensor function
+static void ggml_backend_metal_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // Split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_metal * extra = (ggml_tensor_extra_metal *)tensor->extra;
+
+    // For Metal, we only have one device
+    const int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    const int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return;
+    }
+
+    const size_t offset_split = row_low * nb1;
+    const size_t original_size = ggml_nbytes_split(tensor, nrows_split);
+
+    // note: the row padding beyond original_size was already zeroed in init_tensor
+    const char * buf_host = (const char *)data + offset_split;
+
+    // Copy data to Metal buffer
+    memcpy([extra->data_device[id] contents], buf_host, original_size);
+}
+
+// Buffer get tensor function
+static void ggml_backend_metal_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // Split tensors must always be retrieved in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
+
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_metal * extra = (ggml_tensor_extra_metal *)tensor->extra;
+
+    // For Metal, we only have one device
+    const int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+    const int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return;
+    }
+
+    const size_t offset_split = row_low * nb1;
+    const size_t original_size = ggml_nbytes_split(tensor, nrows_split);
+
+    char * buf_host = (char *)data + offset_split;
+
+    // Copy data from Metal buffer
+    memcpy(buf_host, [extra->data_device[id] contents], original_size);
+}
+
+// Buffer clear function
+static void ggml_backend_metal_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    // Not implemented for split buffers; tensor memory is zero-initialized in init_tensor
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(value);
+}
+
+// Buffer interface
+static const ggml_backend_buffer_i ggml_backend_metal_split_buffer_interface = {
+    /* .free_buffer   = */ ggml_backend_metal_split_buffer_free_buffer,
+    /* .get_base      = */ ggml_backend_metal_split_buffer_get_base,
+    /* .init_tensor   = */ ggml_backend_metal_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
+    /* .set_tensor    = */ ggml_backend_metal_split_buffer_set_tensor,
+    /* .get_tensor    = */ ggml_backend_metal_split_buffer_get_tensor,
+    /* .cpy_tensor    = */ NULL,
+    /* .clear         = */ ggml_backend_metal_split_buffer_clear,
+    /* .reset         = */ NULL,
+};
+
+// Buffer type interface functions
+static const char * ggml_backend_metal_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+    return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_metal_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_split_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    // Since we don't know the exact split after rounding, we cannot allocate the device buffers at this point.
+    // Instead, we allocate them for each tensor separately in init_tensor.
+    // However, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+    // as returned by get_alloc_size. This limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+    ggml_backend_metal_split_buffer_context * ctx = new ggml_backend_metal_split_buffer_context();
+
+    return ggml_backend_buffer_init(buft, ggml_backend_metal_split_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_metal_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+    size_t total_size = 0;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    // For Metal, we only have one device
+    const int id = 0;
+    int64_t row_low, row_high;
+    get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+    const int64_t nrows_split = row_high - row_low;
+    if (nrows_split == 0) {
+        return total_size;
+    }
+
+    total_size += ggml_nbytes_split(tensor, nrows_split);
+
+    // Pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+    if (ne0 % MATRIX_ROW_PADDING != 0) {
+        total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+    }
+
+    return total_size;
+}
+
+static bool ggml_backend_metal_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+// Buffer type interface
+static const ggml_backend_buffer_type_i ggml_backend_metal_split_buffer_type_interface = {
+    /* .get_name       = */ ggml_backend_metal_split_buffer_type_get_name,
+    /* .alloc_buffer   = */ ggml_backend_metal_split_buffer_type_alloc_buffer,
+    /* .get_alignment  = */ ggml_backend_metal_split_buffer_type_get_alignment,
+    /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size = */ ggml_backend_metal_split_buffer_type_get_alloc_size,
+    /* .is_host        = */ ggml_backend_metal_split_buffer_type_is_host,
+};
+
+#endif // __cplusplus
+
+// Main function to create the Metal split buffer type
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_split_buffer_type(int main_device, const float * tensor_split) {
+    GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);
+    GGML_UNUSED(tensor_split); // with a single Metal device the split fractions are ignored
+#ifdef __cplusplus
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    static std::map<std::pair<int, std::array<float, 1>>, struct ggml_backend_buffer_type> buft_map;
+
+    std::array<float, 1> tensor_split_arr = {};
+
+    // For Metal, we only support one device, so we simplify the tensor split logic
+    tensor_split_arr[0] = 1.0f; // All tensors go to the single Metal device
+
+    auto it = buft_map.find({main_device, tensor_split_arr});
+    if (it != buft_map.end()) {
+        GGML_LOG_INFO("%s: returning existing buffer type\n", __func__);
+        return &it->second;
+    }
+
+    auto * ctx = new ggml_backend_metal_split_buffer_type_context{
+        main_device,
+        tensor_split_arr,
+        std::string("Metal_Split"),
+    };
+
+    struct ggml_backend_buffer_type buft {
+        /* .iface   = */ ggml_backend_metal_split_buffer_type_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), main_device),
+        /* .context = */ ctx,
+    };
+
+    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
+    GGML_LOG_INFO("%s: created new Metal split buffer type\n", __func__);
+    return &result.first->second;
+#else
+    // For C builds, return the regular Metal buffer type
+    GGML_LOG_INFO("%s: C build, returning regular Metal buffer type\n", __func__);
+    return ggml_backend_metal_buffer_type();
+#endif
+}
+
 // TODO: make thread-safe
 ggml_backend_reg_t ggml_backend_metal_reg(void) {
     ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
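
Taken together, a caller could place a contiguous 2D weight into the new buffer type and upload it whole, roughly as follows (a hedged sketch against the public ggml API; the function name and tensor shapes are illustrative):

    #include <stdlib.h>
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    static void example_split_alloc(void) {
        struct ggml_init_params ip = {
            /* .mem_size   = */ 8*ggml_tensor_overhead(),
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ true, // tensor data lives in the backend buffer
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);

        ggml_backend_buffer_type_t buft = ggml_backend_metal_split_buffer_type(0, NULL);
        ggml_backend_buffer_t      buf  = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);

        // split tensors must be written in their entirety (offset 0, full size),
        // which routes through ggml_backend_metal_split_buffer_set_tensor above
        float * data = calloc(4096*4096, sizeof(float)); // placeholder weights
        ggml_backend_tensor_set(w, data, 0, ggml_nbytes(w));

        free(data);
        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
    }

In llama.cpp, this path would be exercised by --split-mode row, which discovers the backend's "ggml_backend_split_buffer_type" through the registry lookup shown earlier.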
