Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 132 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <Metal/Metal.h>

#include <stdatomic.h>
#include <stdbool.h>

#ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0
Expand Down Expand Up @@ -39,33 +40,153 @@ @implementation GGMLMetalClass
// MTLFunctionConstantValues wrapper
//

enum ggml_metal_cv_value_type {
GGML_METAL_CV_VALUE_TYPE_INT16,
GGML_METAL_CV_VALUE_TYPE_INT32,
GGML_METAL_CV_VALUE_TYPE_BOOL,
};

struct ggml_metal_cv_value {
enum ggml_metal_cv_value_type type;
int32_t idx;
union {
int16_t i16;
int32_t i32;
bool b;
} value;
};

#define GGML_METAL_MAX_FUNCTION_CONSTANTS 128

struct ggml_metal_cv {
MTLFunctionConstantValues * obj;
struct ggml_metal_cv_value values[GGML_METAL_MAX_FUNCTION_CONSTANTS];
size_t value_count;
};

static void ggml_metal_cv_apply_value(ggml_metal_cv_t cv, size_t index) {
if (!cv || !cv->obj || index >= cv->value_count) {
return;
}

const struct ggml_metal_cv_value * entry = &cv->values[index];

switch (entry->type) {
case GGML_METAL_CV_VALUE_TYPE_INT16:
[cv->obj setConstantValue:&entry->value.i16 type:MTLDataTypeShort atIndex:entry->idx];
break;
case GGML_METAL_CV_VALUE_TYPE_INT32:
[cv->obj setConstantValue:&entry->value.i32 type:MTLDataTypeInt atIndex:entry->idx];
break;
case GGML_METAL_CV_VALUE_TYPE_BOOL:
[cv->obj setConstantValue:&entry->value.b type:MTLDataTypeBool atIndex:entry->idx];
break;
}
}

static void ggml_metal_cv_record_value(ggml_metal_cv_t cv, enum ggml_metal_cv_value_type type, int32_t idx, const void * value) {
if (!cv) {
return;
}

size_t slot = cv->value_count;

for (size_t i = 0; i < cv->value_count; ++i) {
if (cv->values[i].idx == idx) {
slot = i;
break;
}
}

if (slot == cv->value_count) {
if (cv->value_count >= GGML_METAL_MAX_FUNCTION_CONSTANTS) {
GGML_LOG_ERROR("%s: error: exceeded maximum number (%d) of stored Metal function constants\n", __func__, GGML_METAL_MAX_FUNCTION_CONSTANTS);
if (cv->obj) {
switch (type) {
case GGML_METAL_CV_VALUE_TYPE_INT16:
[cv->obj setConstantValue:value type:MTLDataTypeShort atIndex:idx];
break;
case GGML_METAL_CV_VALUE_TYPE_INT32:
[cv->obj setConstantValue:value type:MTLDataTypeInt atIndex:idx];
break;
case GGML_METAL_CV_VALUE_TYPE_BOOL:
[cv->obj setConstantValue:value type:MTLDataTypeBool atIndex:idx];
break;
}
}
return;
}

slot = cv->value_count++;
cv->values[slot].idx = idx;
}

cv->values[slot].type = type;

switch (type) {
case GGML_METAL_CV_VALUE_TYPE_INT16:
cv->values[slot].value.i16 = *(const int16_t *) value;
break;
case GGML_METAL_CV_VALUE_TYPE_INT32:
cv->values[slot].value.i32 = *(const int32_t *) value;
break;
case GGML_METAL_CV_VALUE_TYPE_BOOL:
cv->values[slot].value.b = *(const bool *) value;
break;
}

ggml_metal_cv_apply_value(cv, slot);
}

static bool ggml_metal_cv_ensure_constants(ggml_metal_cv_t cv) {
if (!cv) {
return false;
}

if (cv->obj) {
return true;
}

cv->obj = [[MTLFunctionConstantValues alloc] init];
if (!cv->obj) {
GGML_LOG_ERROR("%s: error: failed to allocate Metal function constant values container\n", __func__);
return false;
}

for (size_t i = 0; i < cv->value_count; ++i) {
ggml_metal_cv_apply_value(cv, i);
}

return true;
}

ggml_metal_cv_t ggml_metal_cv_init(void) {
ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv));

res->obj = [[MTLFunctionConstantValues alloc] init];
res->value_count = 0;

return res;
}

void ggml_metal_cv_free(ggml_metal_cv_t cv) {
if (!cv) {
return;
}
[cv->obj release];
free(cv);
}

void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_INT16, idx, &value);
}

void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_INT32, idx, &value);
}

void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
[cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx];
ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_BOOL, idx, &value);
}

//
Expand Down Expand Up @@ -336,8 +457,15 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l

GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);

if (cv && !ggml_metal_cv_ensure_constants(cv)) {
GGML_LOG_WARN("%s: warning: failed to materialize function constants for pipeline: base = '%s', name = '%s'\n", __func__, base, name);
}

id<MTLFunction> mtl_function;
if (!cv) {
if (!cv || cv->obj == nil) {
if (cv && cv->obj == nil) {
GGML_LOG_WARN("%s: warning: compiling pipeline without function constants: base = '%s', name = '%s'\n", __func__, base, name);
}
mtl_function = [lib->obj newFunctionWithName:base_func];
} else {
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
Expand Down