Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include <Metal/Metal.h>

#include <stdatomic.h>

#ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0
#endif
Expand All @@ -22,6 +24,9 @@
// overload of MTLGPUFamilyMetal3 (not available in some environments)
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;

// virtual address for GPU memory allocations
//
// process-wide counter used to hand out placeholder base addresses for buffers that are not
// shared with the host: ggml_metal_buffer_init() atomically fetches the current value and
// advances it by the aligned buffer size, so each such buffer gets its own range
// the value is derived purely from this counter (it does not come from an allocator) and
// starts at a non-zero value so that the resulting pointer is never NULL
// NOTE(review): the counter is only ever incremented in the visible code - confirm that
//               wrap-around is not a concern for long-running processes
static atomic_uintptr_t g_addr_device = 0x000000400ULL;

#if !GGML_METAL_EMBED_LIBRARY
// Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject
Expand Down Expand Up @@ -827,7 +832,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};

struct ggml_metal_buffer {
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
void * all_data;
size_t all_size;

// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
Expand Down Expand Up @@ -965,14 +970,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
if (shared) {
res->all_data = ggml_metal_host_malloc(size_aligned);
res->is_shared = true;
res->owned = true;
} else {
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
res->all_data = (void *) 0x000000400ULL;
// use virtual address from g_addr_device counter
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
res->is_shared = false;
}
res->all_size = size_aligned;

res->owned = true;

res->device = ggml_metal_device_get_obj(dev);
res->queue = ggml_metal_device_get_queue(dev);

Expand All @@ -983,15 +989,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
res->buffers[0].metal = nil;

if (size_aligned > 0) {
if (props_dev->use_shared_buffers &&shared) {
if (props_dev->use_shared_buffers && shared) {
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
length:size_aligned
options:MTLResourceStorageModeShared
deallocator:nil];
} else {
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];

res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
}
}

Expand Down Expand Up @@ -1139,7 +1143,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {

void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
if (buf->is_shared) {
memset((char *)tensor->data + offset, value, size);
memset((char *) tensor->data + offset, value, size);
return;
}

Expand Down Expand Up @@ -1168,7 +1172,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor

void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
if (buf->is_shared) {
memcpy((char *)tensor->data + offset, data, size);
memcpy((char *) tensor->data + offset, data, size);
return;
}

Expand Down Expand Up @@ -1223,7 +1227,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *

void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
if (buf->is_shared) {
memcpy(data, (const char *)tensor->data + offset, size);
memcpy(data, (const char *) tensor->data + offset, size);
return;
}

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ typedef struct {
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope;

typedef struct {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2969,6 +2969,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3,
/* src2 =*/ op->src[2] != nullptr,
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3748,7 +3748,7 @@ kernel void kernel_rope_norm(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

Expand Down Expand Up @@ -3801,7 +3801,7 @@ kernel void kernel_rope_neox(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

Expand Down Expand Up @@ -3872,7 +3872,7 @@ kernel void kernel_rope_multi(

const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

Expand Down Expand Up @@ -3939,7 +3939,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope

const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

Expand Down
Loading