Skip to content

Commit e4e068a

Browse files
committed
Fixing the build.
1 parent b53f098 commit e4e068a

File tree

1 file changed

+34
-61
lines changed

1 file changed

+34
-61
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 34 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5878,6 +5878,20 @@ static enum ggml_status ggml_metal_graph_compute(
58785878

58795879
// backend interface
58805880

5881+
// Create a "split" buffer type for the Metal backend.
//
// Metal exposes exactly one device, so there is nothing to split: this
// simply forwards to the regular Metal buffer type. The signature matches
// the generic ggml_backend_split_buffer_type proc-address contract
// (main_device + per-device tensor_split fractions) so the registry can
// hand it out uniformly across backends.
//
// main_device:  index of the primary device (logged only; Metal has one device)
// tensor_split: per-device split fractions (ignored; Metal has one device)
// returns:      the singleton Metal buffer type
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_split_buffer_type(int main_device, const float * tensor_split) {
    GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);

    // For Metal split buffer type, we return the regular Metal buffer type
    // since Metal currently only supports one device
    ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
    GGML_LOG_INFO("%s: returning Metal buffer type\n", __func__);
    return buft;

    // NOTE: main_device is actually used (in the log above), so no
    // GGML_UNUSED(main_device) is needed; tensor_split is genuinely unused.
    GGML_UNUSED(tensor_split);
}
5893+
5894+
58815895
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
58825896
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
58835897

@@ -6593,7 +6607,7 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
65936607

65946608
static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
65956609
if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
6596-
return (void *)ggml_backend_metal_split_buffer_type;
6610+
return (void *)ggml_backend_split_buffer_type;
65976611
}
65986612
if (strcmp(name, "ggml_backend_get_features") == 0) {
65996613
return (void *)ggml_backend_metal_get_features;
@@ -6631,7 +6645,7 @@ static void ggml_metal_cleanup(void) {
66316645
};
66326646

66336647
// Buffer type context
6634-
struct ggml_backend_metal_split_buffer_type_context {
6648+
struct ggml_backend_split_buffer_type_context {
66356649
int main_device;
66366650
std::array<float, 1> tensor_split; // Metal only supports one device, but keeping array for API consistency
66376651
std::string name;
@@ -6660,6 +6674,7 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl
66606674

66616675
// Tensor split calculation
66626676
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, 1> & tensor_split, int id) {
6677+
GGML_LOG_INFO("Returning Row Splits.\n");
66636678
// For Metal, we only have one device, so all rows go to device 0
66646679
if (id == 0) {
66656680
*row_low = 0;
@@ -6692,7 +6707,7 @@ static enum ggml_status ggml_backend_metal_split_buffer_init_tensor(ggml_backend
66926707
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
66936708

66946709
ggml_backend_metal_split_buffer_context * ctx = (ggml_backend_metal_split_buffer_context *)buffer->context;
6695-
ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
6710+
ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
66966711

66976712
const int64_t ne0 = tensor->ne[0];
66986713

@@ -6739,7 +6754,7 @@ static void ggml_backend_metal_split_buffer_set_tensor(ggml_backend_buffer_t buf
67396754
GGML_ASSERT(size == ggml_nbytes(tensor));
67406755
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
67416756

6742-
ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
6757+
ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
67436758

67446759
const int64_t ne0 = tensor->ne[0];
67456760
const size_t nb1 = tensor->nb[1];
@@ -6777,7 +6792,7 @@ static void ggml_backend_metal_split_buffer_get_tensor(ggml_backend_buffer_t buf
67776792
GGML_ASSERT(size == ggml_nbytes(tensor));
67786793
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
67796794

6780-
ggml_backend_metal_split_buffer_type_context * buft_ctx = (ggml_backend_metal_split_buffer_type_context *)buffer->buft->context;
6795+
ggml_backend_split_buffer_type_context * buft_ctx = (ggml_backend_split_buffer_type_context *)buffer->buft->context;
67816796

67826797
const int64_t ne0 = tensor->ne[0];
67836798
const size_t nb1 = tensor->nb[1];
@@ -6829,16 +6844,16 @@ static void ggml_backend_metal_split_buffer_clear(ggml_backend_buffer_t buffer,
68296844
};
68306845

68316846
// Buffer type interface functions
6832-
static const char * ggml_backend_metal_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
6833-
ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
6847+
// Returns the human-readable name stored in the split buffer-type context.
// Also serves as the identity marker used by ggml_backend_buft_is_metal_split.
static const char * ggml_backend_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    const auto * type_ctx = (const ggml_backend_split_buffer_type_context *) buft->context;
    return type_ctx->name.c_str();
}
68366851

68376852
// A buffer type is a Metal split type iff it uses the split get_name
// callback — the callback pointer doubles as a cheap type tag.
static bool ggml_backend_buft_is_metal_split(ggml_backend_buffer_type_t buft) {
    const bool is_split = buft->iface.get_name == ggml_backend_split_buffer_type_get_name;
    return is_split;
}
68406855

6841-
static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
6856+
static ggml_backend_buffer_t ggml_backend_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
68426857
// Since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
68436858
// Instead, we allocate them for each tensor separately in init_tensor
68446859
// However, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
@@ -6848,14 +6863,14 @@ static ggml_backend_buffer_t ggml_backend_metal_split_buffer_type_alloc_buffer(g
68486863
return ggml_backend_buffer_init(buft, ggml_backend_metal_split_buffer_interface, ctx, size);
68496864
}
68506865

6851-
static size_t ggml_backend_metal_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
6866+
// Alignment requirement for tensors placed in split buffers.
// Fixed at 128 bytes regardless of the buffer type instance.
static size_t ggml_backend_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return 128;
}
68566871

6857-
static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
6858-
ggml_backend_metal_split_buffer_type_context * ctx = (ggml_backend_metal_split_buffer_type_context *)buft->context;
6872+
static size_t ggml_backend_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
6873+
ggml_backend_split_buffer_type_context * ctx = (ggml_backend_split_buffer_type_context *)buft->context;
68596874
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
68606875

68616876
size_t total_size = 0;
@@ -6882,66 +6897,24 @@ static size_t ggml_backend_metal_split_buffer_type_get_alloc_size(ggml_backend_b
68826897
return total_size;
68836898
}
68846899

6885-
static bool ggml_backend_metal_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
6900+
// Split buffers live in device (Metal) memory, never in host memory,
// so this unconditionally reports false.
static bool ggml_backend_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    return false;
}
68906905

68916906
// Buffer type interface
6892-
static const ggml_backend_buffer_type_i ggml_backend_metal_split_buffer_type_interface = {
6893-
/* .get_name = */ ggml_backend_metal_split_buffer_type_get_name,
6894-
/* .alloc_buffer = */ ggml_backend_metal_split_buffer_type_alloc_buffer,
6895-
/* .get_alignment = */ ggml_backend_metal_split_buffer_type_get_alignment,
6907+
// vtable wiring the split buffer type into the generic ggml-backend
// buffer-type interface; field order must match ggml_backend_buffer_type_i.
static const ggml_backend_buffer_type_i ggml_backend_split_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_split_buffer_type_get_name,
    /* .alloc_buffer   = */ ggml_backend_split_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_split_buffer_type_get_alignment,
    /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
    /* .get_alloc_size = */ ggml_backend_split_buffer_type_get_alloc_size,
    /* .is_host        = */ ggml_backend_split_buffer_type_is_host,
};
69006915

69016916
#endif // __cplusplus
69026917

6903-
// Main function to create Metal split buffer type
6904-
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_split_buffer_type(int main_device, const float * tensor_split) {
6905-
GGML_LOG_INFO("%s: creating Metal split buffer type, main_device=%d\n", __func__, main_device);
6906-
#ifdef __cplusplus
6907-
static std::mutex mutex;
6908-
std::lock_guard<std::mutex> lock(mutex);
6909-
6910-
static std::map<std::pair<int, std::array<float, 1>>, struct ggml_backend_buffer_type> buft_map;
6911-
6912-
std::array<float, 1> tensor_split_arr = {};
6913-
6914-
// For Metal, we only support one device, so we simplify the tensor split logic
6915-
tensor_split_arr[0] = 1.0f; // All tensors go to the single Metal device
6916-
6917-
auto it = buft_map.find({main_device, tensor_split_arr});
6918-
if (it != buft_map.end()) {
6919-
GGML_LOG_INFO("%s: returning existing buffer type\n", __func__);
6920-
return &it->second;
6921-
}
6922-
6923-
auto * ctx = new ggml_backend_metal_split_buffer_type_context{
6924-
main_device,
6925-
tensor_split_arr,
6926-
std::string("Metal_Split"),
6927-
};
6928-
6929-
struct ggml_backend_buffer_type buft {
6930-
/* .iface = */ ggml_backend_metal_split_buffer_type_interface,
6931-
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), main_device),
6932-
/* .context = */ ctx,
6933-
};
6934-
6935-
auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
6936-
GGML_LOG_INFO("%s: created new Metal split buffer type\n", __func__);
6937-
return &result.first->second;
6938-
#else
6939-
// For C builds, return the regular Metal buffer type
6940-
GGML_LOG_INFO("%s: C build, returning regular Metal buffer type\n", __func__);
6941-
return ggml_backend_metal_buffer_type();
6942-
#endif
6943-
}
6944-
69456918
// TODO: make thread-safe
69466919
ggml_backend_reg_t ggml_backend_metal_reg(void) {
69476920
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);

0 commit comments

Comments
 (0)