diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 360fbe19f0fb6..36650aa9adeb7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -8,6 +8,7 @@ #include #include +#include #ifndef TARGET_OS_VISION #define TARGET_OS_VISION 0 @@ -39,33 +40,153 @@ @implementation GGMLMetalClass // MTLFunctionConstantValues wrapper // +enum ggml_metal_cv_value_type { + GGML_METAL_CV_VALUE_TYPE_INT16, + GGML_METAL_CV_VALUE_TYPE_INT32, + GGML_METAL_CV_VALUE_TYPE_BOOL, +}; + +struct ggml_metal_cv_value { + enum ggml_metal_cv_value_type type; + int32_t idx; + union { + int16_t i16; + int32_t i32; + bool b; + } value; +}; + +#define GGML_METAL_MAX_FUNCTION_CONSTANTS 128 + struct ggml_metal_cv { MTLFunctionConstantValues * obj; + struct ggml_metal_cv_value values[GGML_METAL_MAX_FUNCTION_CONSTANTS]; + size_t value_count; }; +static void ggml_metal_cv_apply_value(ggml_metal_cv_t cv, size_t index) { + if (!cv || !cv->obj || index >= cv->value_count) { + return; + } + + const struct ggml_metal_cv_value * entry = &cv->values[index]; + + switch (entry->type) { + case GGML_METAL_CV_VALUE_TYPE_INT16: + [cv->obj setConstantValue:&entry->value.i16 type:MTLDataTypeShort atIndex:entry->idx]; + break; + case GGML_METAL_CV_VALUE_TYPE_INT32: + [cv->obj setConstantValue:&entry->value.i32 type:MTLDataTypeInt atIndex:entry->idx]; + break; + case GGML_METAL_CV_VALUE_TYPE_BOOL: + [cv->obj setConstantValue:&entry->value.b type:MTLDataTypeBool atIndex:entry->idx]; + break; + } +} + +static void ggml_metal_cv_record_value(ggml_metal_cv_t cv, enum ggml_metal_cv_value_type type, int32_t idx, const void * value) { + if (!cv) { + return; + } + + size_t slot = cv->value_count; + + for (size_t i = 0; i < cv->value_count; ++i) { + if (cv->values[i].idx == idx) { + slot = i; + break; + } + } + + if (slot == cv->value_count) { + if (cv->value_count >= GGML_METAL_MAX_FUNCTION_CONSTANTS) { + GGML_LOG_ERROR("%s: error: exceeded maximum number (%d) of stored Metal function constants\n", __func__, GGML_METAL_MAX_FUNCTION_CONSTANTS); + if (cv->obj) { + switch (type) { + case GGML_METAL_CV_VALUE_TYPE_INT16: + [cv->obj setConstantValue:value type:MTLDataTypeShort atIndex:idx]; + break; + case GGML_METAL_CV_VALUE_TYPE_INT32: + [cv->obj setConstantValue:value type:MTLDataTypeInt atIndex:idx]; + break; + case GGML_METAL_CV_VALUE_TYPE_BOOL: + [cv->obj setConstantValue:value type:MTLDataTypeBool atIndex:idx]; + break; + } + } + return; + } + + slot = cv->value_count++; + cv->values[slot].idx = idx; + } + + cv->values[slot].type = type; + + switch (type) { + case GGML_METAL_CV_VALUE_TYPE_INT16: + cv->values[slot].value.i16 = *(const int16_t *) value; + break; + case GGML_METAL_CV_VALUE_TYPE_INT32: + cv->values[slot].value.i32 = *(const int32_t *) value; + break; + case GGML_METAL_CV_VALUE_TYPE_BOOL: + cv->values[slot].value.b = *(const bool *) value; + break; + } + + ggml_metal_cv_apply_value(cv, slot); +} + +static bool ggml_metal_cv_ensure_constants(ggml_metal_cv_t cv) { + if (!cv) { + return false; + } + + if (cv->obj) { + return true; + } + + cv->obj = [[MTLFunctionConstantValues alloc] init]; + if (!cv->obj) { + GGML_LOG_ERROR("%s: error: failed to allocate Metal function constant values container\n", __func__); + return false; + } + + for (size_t i = 0; i < cv->value_count; ++i) { + ggml_metal_cv_apply_value(cv, i); + } + + return true; +} + ggml_metal_cv_t ggml_metal_cv_init(void) { ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv)); res->obj = [[MTLFunctionConstantValues alloc] init]; + res->value_count = 0; return res; } void ggml_metal_cv_free(ggml_metal_cv_t cv) { + if (!cv) { + return; + } [cv->obj release]; free(cv); } void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) { - [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx]; + ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_INT16, idx, &value); } void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) { - [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx]; + ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_INT32, idx, &value); } void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { - [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx]; + ggml_metal_cv_record_value(cv, GGML_METAL_CV_VALUE_TYPE_BOOL, idx, &value); } // @@ -336,8 +457,15 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name); + if (cv && !ggml_metal_cv_ensure_constants(cv)) { + GGML_LOG_WARN("%s: warning: failed to materialize function constants for pipeline: base = '%s', name = '%s'\n", __func__, base, name); + } + id mtl_function; - if (!cv) { + if (!cv || cv->obj == nil) { + if (cv && cv->obj == nil) { + GGML_LOG_WARN("%s: warning: compiling pipeline without function constants: base = '%s', name = '%s'\n", __func__, base, name); + } mtl_function = [lib->obj newFunctionWithName:base_func]; } else { mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];