Skip to content

Commit 996e479

Browse files
committed
metal : add kernel arg structs (wip)
1 parent 6423c65 commit 996e479

File tree

3 files changed

+116
-125
lines changed

3 files changed

+116
-125
lines changed

ggml/src/ggml-common.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,36 @@ typedef struct {
418418
} block_iq4_xs;
419419
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
420420

421+
// Kernel-argument structs shared between the Metal host code and the Metal shaders.
// Only declared when the including file defines GGML_COMMON_DECL_METAL_KARGS.
#if defined(GGML_COMMON_DECL_METAL_KARGS)
422+
// Arguments for the RoPE kernels (kernel_rope_norm / kernel_rope_neox).
// The host fills one instance and uploads it with a single setBytes call at buffer index 4;
// the kernels receive it as `constant ggml_metal_kargs_rope & args`.
// The struct is copied as raw bytes, so field order and types must stay identical
// in the host and shader compilations.
// NOTE(review): the ne* fields are 32-bit here, while the per-field setBytes calls this
// replaces used sizeof(int64_t) - presumably the host values are narrowed; confirm they fit.
typedef struct {
423+
int32_t ne00; // ne00..ne03: not read in the visible kernel code; presumably src0 element counts per dim - TODO confirm
424+
int32_t ne01;
425+
int32_t ne02;
426+
int32_t ne03;
427+
uint64_t nb00; // nb00..nb03: byte strides added to the src0 char pointer (dims 0..3)
428+
uint64_t nb01;
429+
uint64_t nb02;
430+
uint64_t nb03;
431+
int32_t ne0; // upper bound of the kernels' i0 loop
432+
int32_t ne1; // ne1..ne3: not read in the visible kernel code; presumably dst element counts - TODO confirm
433+
int32_t ne2;
434+
int32_t ne3;
435+
uint64_t nb0; // nb0..nb3: byte strides added to the dst char pointer (dims 0..3)
436+
uint64_t nb1;
437+
uint64_t nb2;
438+
uint64_t nb3;
439+
int32_t n_past; // not read in the visible kernel code - NOTE(review): confirm whether still needed
440+
int32_t n_dims; // elements with i0 < n_dims are rotated, the rest are copied unchanged
441+
int32_t n_ctx_orig; // forwarded to rope_yarn_corr_dims
442+
float freq_base; // base of pow() when computing theta; also forwarded to rope_yarn_corr_dims
443+
float freq_scale; // forwarded to rope_yarn
444+
float ext_factor; // forwarded to rope_yarn
445+
float attn_factor; // passed to rope_yarn as its mscale argument
446+
float beta_fast; // forwarded to rope_yarn_corr_dims
447+
float beta_slow; // forwarded to rope_yarn_corr_dims
448+
} ggml_metal_kargs_rope;
449+
#endif
450+
421451
#endif // GGML_COMMON_DECL
422452
#endif // GGML_COMMON_DECL
423453

ggml/src/ggml-metal.m

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#import "ggml-impl.h"
44
#import "ggml-backend-impl.h"
55

6+
#define GGML_COMMON_DECL_C
7+
#define GGML_COMMON_DECL_METAL_KARGS
8+
#include "ggml-common.h"
9+
610
#import <Foundation/Foundation.h>
711

812
#import <Metal/Metal.h>
@@ -2702,40 +2706,44 @@ static void ggml_metal_encode_node(
27022706
};
27032707
}
27042708

2709+
ggml_metal_kargs_rope args = {
2710+
.ne00 = ne00,
2711+
.ne01 = ne01,
2712+
.ne02 = ne02,
2713+
.ne03 = ne03,
2714+
.nb00 = nb00,
2715+
.nb01 = nb01,
2716+
.nb02 = nb02,
2717+
.nb03 = nb03,
2718+
.ne0 = ne0,
2719+
.ne1 = ne1,
2720+
.ne2 = ne2,
2721+
.ne3 = ne3,
2722+
.nb0 = nb0,
2723+
.nb1 = nb1,
2724+
.nb2 = nb2,
2725+
.nb3 = nb3,
2726+
.n_past = n_past,
2727+
.n_dims = n_dims,
2728+
.n_ctx_orig = n_ctx_orig,
2729+
.freq_base = freq_base,
2730+
.freq_scale = freq_scale,
2731+
.ext_factor = ext_factor,
2732+
.attn_factor = attn_factor,
2733+
.beta_fast = beta_fast,
2734+
.beta_slow = beta_slow,
2735+
};
2736+
27052737
[encoder setComputePipelineState:pipeline];
2706-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2707-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2738+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2739+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
27082740
if (id_src2 != nil) {
2709-
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2741+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
27102742
} else {
2711-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2743+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
27122744
}
2713-
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2714-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2715-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2716-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2717-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2718-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2719-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2720-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2721-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2722-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2723-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2724-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2725-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2726-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2727-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2728-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2729-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2730-
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2731-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2732-
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2733-
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2734-
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2735-
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2736-
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2737-
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2738-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2745+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2746+
[encoder setBytes:&args length:sizeof(args) atIndex:4];
27392747

27402748
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
27412749
} break;

ggml/src/ggml-metal.metal

Lines changed: 48 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#define GGML_COMMON_DECL_METAL
2+
#define GGML_COMMON_DECL_METAL_KARGS
23
#define GGML_COMMON_IMPL_METAL
34
#include "ggml-common.h"
45

@@ -2229,7 +2230,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
22292230
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
22302231
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
22312232
static void rope_yarn(
2232-
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
2233+
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
22332234
thread float * cos_theta, thread float * sin_theta) {
22342235
// Get n-d rotational scaling corrected for extrapolation
22352236
float theta_interp = freq_scale * theta_extrap;
@@ -2261,74 +2262,50 @@ static void rope_yarn_corr_dims(
22612262

22622263
template<typename T>
22632264
kernel void kernel_rope_norm(
2264-
device const void * src0,
2265-
device const int32_t * src1,
2266-
device const float * src2,
2267-
device float * dst,
2268-
constant int64_t & ne00,
2269-
constant int64_t & ne01,
2270-
constant int64_t & ne02,
2271-
constant int64_t & ne03,
2272-
constant uint64_t & nb00,
2273-
constant uint64_t & nb01,
2274-
constant uint64_t & nb02,
2275-
constant uint64_t & nb03,
2276-
constant int64_t & ne0,
2277-
constant int64_t & ne1,
2278-
constant int64_t & ne2,
2279-
constant int64_t & ne3,
2280-
constant uint64_t & nb0,
2281-
constant uint64_t & nb1,
2282-
constant uint64_t & nb2,
2283-
constant uint64_t & nb3,
2284-
constant int & n_past,
2285-
constant int & n_dims,
2286-
constant int & n_ctx_orig,
2287-
constant float & freq_base,
2288-
constant float & freq_scale,
2289-
constant float & ext_factor,
2290-
constant float & attn_factor,
2291-
constant float & beta_fast,
2292-
constant float & beta_slow,
2265+
device const char * src0,
2266+
device const char * src1,
2267+
device const char * src2,
2268+
device char * dst,
2269+
constant ggml_metal_kargs_rope & args,
22932270
uint tiitg[[thread_index_in_threadgroup]],
2294-
uint3 tptg[[threads_per_threadgroup]],
2271+
uint3 tptg [[threads_per_threadgroup]],
22952272
uint3 tgpig[[threadgroup_position_in_grid]]) {
2296-
const int64_t i3 = tgpig[2];
2297-
const int64_t i2 = tgpig[1];
2298-
const int64_t i1 = tgpig[0];
2273+
const int i3 = tgpig[2];
2274+
const int i2 = tgpig[1];
2275+
const int i1 = tgpig[0];
22992276

23002277
float corr_dims[2];
2301-
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
2278+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
23022279

2303-
device const int32_t * pos = src1;
2280+
device const int32_t * pos = (device const int32_t *) src1;
23042281

23052282
const float theta_base = (float) pos[i2];
2306-
const float inv_ndims = -1.f/n_dims;
2283+
const float inv_ndims = -1.f/args.n_dims;
23072284

23082285
float cos_theta;
23092286
float sin_theta;
23102287

2311-
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
2312-
if (i0 < n_dims) {
2313-
const int64_t ic = i0/2;
2288+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2289+
if (i0 < args.n_dims) {
2290+
const int ic = i0/2;
23142291

2315-
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
2292+
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
23162293

2317-
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
2294+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
23182295

2319-
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
2296+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
23202297

2321-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2322-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2298+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2299+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
23232300

23242301
const float x0 = src[0];
23252302
const float x1 = src[1];
23262303

23272304
dst_data[0] = x0*cos_theta - x1*sin_theta;
23282305
dst_data[1] = x0*sin_theta + x1*cos_theta;
23292306
} else {
2330-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2331-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2307+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2308+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
23322309

23332310
dst_data[0] = src[0];
23342311
dst_data[1] = src[1];
@@ -2338,74 +2315,50 @@ kernel void kernel_rope_norm(
23382315

23392316
template<typename T>
23402317
kernel void kernel_rope_neox(
2341-
device const void * src0,
2342-
device const int32_t * src1,
2343-
device const float * src2,
2344-
device float * dst,
2345-
constant int64_t & ne00,
2346-
constant int64_t & ne01,
2347-
constant int64_t & ne02,
2348-
constant int64_t & ne03,
2349-
constant uint64_t & nb00,
2350-
constant uint64_t & nb01,
2351-
constant uint64_t & nb02,
2352-
constant uint64_t & nb03,
2353-
constant int64_t & ne0,
2354-
constant int64_t & ne1,
2355-
constant int64_t & ne2,
2356-
constant int64_t & ne3,
2357-
constant uint64_t & nb0,
2358-
constant uint64_t & nb1,
2359-
constant uint64_t & nb2,
2360-
constant uint64_t & nb3,
2361-
constant int & n_past,
2362-
constant int & n_dims,
2363-
constant int & n_ctx_orig,
2364-
constant float & freq_base,
2365-
constant float & freq_scale,
2366-
constant float & ext_factor,
2367-
constant float & attn_factor,
2368-
constant float & beta_fast,
2369-
constant float & beta_slow,
2318+
device const char * src0,
2319+
device const char * src1,
2320+
device const char * src2,
2321+
device char * dst,
2322+
constant ggml_metal_kargs_rope & args,
23702323
uint tiitg[[thread_index_in_threadgroup]],
23712324
uint3 tptg[[threads_per_threadgroup]],
23722325
uint3 tgpig[[threadgroup_position_in_grid]]) {
2373-
const int64_t i3 = tgpig[2];
2374-
const int64_t i2 = tgpig[1];
2375-
const int64_t i1 = tgpig[0];
2326+
const int i3 = tgpig[2];
2327+
const int i2 = tgpig[1];
2328+
const int i1 = tgpig[0];
23762329

23772330
float corr_dims[2];
2378-
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
2331+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
23792332

2380-
device const int32_t * pos = src1;
2333+
device const int32_t * pos = (device const int32_t *) src1;
23812334

23822335
const float theta_base = (float) pos[i2];
2383-
const float inv_ndims = -1.f/n_dims;
2336+
const float inv_ndims = -1.f/args.n_dims;
23842337

23852338
float cos_theta;
23862339
float sin_theta;
23872340

2388-
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
2389-
if (i0 < n_dims) {
2390-
const int64_t ic = i0/2;
2341+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2342+
if (i0 < args.n_dims) {
2343+
const int ic = i0/2;
23912344

2392-
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
2345+
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
23932346

2394-
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
2347+
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
23952348

2396-
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
2349+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
23972350

2398-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
2399-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
2351+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2352+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
24002353

24012354
const float x0 = src[0];
2402-
const float x1 = src[n_dims/2];
2355+
const float x1 = src[args.n_dims/2];
24032356

2404-
dst_data[0] = x0*cos_theta - x1*sin_theta;
2405-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
2357+
dst_data[0] = x0*cos_theta - x1*sin_theta;
2358+
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
24062359
} else {
2407-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2408-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2360+
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2361+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
24092362

24102363
dst_data[0] = src[0];
24112364
dst_data[1] = src[1];

0 commit comments

Comments
 (0)