11#define GGML_COMMON_DECL_METAL
2+ #define GGML_COMMON_DECL_METAL_KARGS
23#define GGML_COMMON_IMPL_METAL
34#include " ggml-common.h"
45
@@ -2229,7 +2230,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
22292230// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
22302231// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
22312232static void rope_yarn (
2232- float theta_extrap, float freq_scale, float corr_dims[2 ], int64_t i0, float ext_factor, float mscale,
2233+ float theta_extrap, float freq_scale, float corr_dims[2 ], int i0, float ext_factor, float mscale,
22332234 thread float * cos_theta, thread float * sin_theta) {
22342235 // Get n-d rotational scaling corrected for extrapolation
22352236 float theta_interp = freq_scale * theta_extrap;
@@ -2261,74 +2262,50 @@ static void rope_yarn_corr_dims(
22612262
22622263template <typename T>
22632264kernel void kernel_rope_norm (
2264- device const void * src0,
2265- device const int32_t * src1,
2266- device const float * src2,
2267- device float * dst,
2268- constant int64_t & ne00,
2269- constant int64_t & ne01,
2270- constant int64_t & ne02,
2271- constant int64_t & ne03,
2272- constant uint64_t & nb00,
2273- constant uint64_t & nb01,
2274- constant uint64_t & nb02,
2275- constant uint64_t & nb03,
2276- constant int64_t & ne0,
2277- constant int64_t & ne1,
2278- constant int64_t & ne2,
2279- constant int64_t & ne3,
2280- constant uint64_t & nb0,
2281- constant uint64_t & nb1,
2282- constant uint64_t & nb2,
2283- constant uint64_t & nb3,
2284- constant int & n_past,
2285- constant int & n_dims,
2286- constant int & n_ctx_orig,
2287- constant float & freq_base,
2288- constant float & freq_scale,
2289- constant float & ext_factor,
2290- constant float & attn_factor,
2291- constant float & beta_fast,
2292- constant float & beta_slow,
2265+ device const char * src0,
2266+ device const char * src1,
2267+ device const char * src2,
2268+ device char * dst,
2269+ constant ggml_metal_kargs_rope & args,
22932270 uint tiitg[[thread_index_in_threadgroup]],
2294- uint3 tptg[[threads_per_threadgroup]],
2271+ uint3 tptg [[threads_per_threadgroup]],
22952272 uint3 tgpig[[threadgroup_position_in_grid]]) {
2296- const int64_t i3 = tgpig[2 ];
2297- const int64_t i2 = tgpig[1 ];
2298- const int64_t i1 = tgpig[0 ];
2273+ const int i3 = tgpig[2 ];
2274+ const int i2 = tgpig[1 ];
2275+ const int i1 = tgpig[0 ];
22992276
23002277 float corr_dims[2 ];
2301- rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
2278+ rope_yarn_corr_dims (args. n_dims , args. n_ctx_orig , args. freq_base , args. beta_fast , args. beta_slow , corr_dims);
23022279
2303- device const int32_t * pos = src1;
2280+ device const int32_t * pos = (device const int32_t *) src1;
23042281
23052282 const float theta_base = (float ) pos[i2];
2306- const float inv_ndims = -1 .f /n_dims;
2283+ const float inv_ndims = -1 .f /args. n_dims ;
23072284
23082285 float cos_theta;
23092286 float sin_theta;
23102287
2311- for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
2312- if (i0 < n_dims) {
2313- const int64_t ic = i0/2 ;
2288+ for (int i0 = 2 *tiitg; i0 < args. ne0 ; i0 += 2 *tptg.x ) {
2289+ if (i0 < args. n_dims ) {
2290+ const int ic = i0/2 ;
23142291
2315- const float theta = theta_base * pow (freq_base, inv_ndims*i0);
2292+ const float theta = theta_base * pow (args. freq_base , inv_ndims*i0);
23162293
2317- const float freq_factor = src2 != src0 ? src2[ic] : 1 .0f ;
2294+ const float freq_factor = src2 != src0 ? ((device const float *) src2) [ic] : 1 .0f ;
23182295
2319- rope_yarn (theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
2296+ rope_yarn (theta/freq_factor, args. freq_scale , corr_dims, i0, args. ext_factor , args. attn_factor , &cos_theta, &sin_theta);
23202297
2321- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2322- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2298+ device const T * const src = (device T *)(src0 + i3*args. nb03 + i2*args. nb02 + i1*args. nb01 + i0*args. nb00 );
2299+ device T * dst_data = (device T *)( dst + i3*args. nb3 + i2*args. nb2 + i1*args. nb1 + i0*args. nb0 );
23232300
23242301 const float x0 = src[0 ];
23252302 const float x1 = src[1 ];
23262303
23272304 dst_data[0 ] = x0*cos_theta - x1*sin_theta;
23282305 dst_data[1 ] = x0*sin_theta + x1*cos_theta;
23292306 } else {
2330- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2331- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2307+ device const T * const src = (device T *)(src0 + i3*args. nb03 + i2*args. nb02 + i1*args. nb01 + i0*args. nb00 );
2308+ device T * dst_data = (device T *)( dst + i3*args. nb3 + i2*args. nb2 + i1*args. nb1 + i0*args. nb0 );
23322309
23332310 dst_data[0 ] = src[0 ];
23342311 dst_data[1 ] = src[1 ];
@@ -2338,74 +2315,50 @@ kernel void kernel_rope_norm(
23382315
23392316template <typename T>
23402317kernel void kernel_rope_neox (
2341- device const void * src0,
2342- device const int32_t * src1,
2343- device const float * src2,
2344- device float * dst,
2345- constant int64_t & ne00,
2346- constant int64_t & ne01,
2347- constant int64_t & ne02,
2348- constant int64_t & ne03,
2349- constant uint64_t & nb00,
2350- constant uint64_t & nb01,
2351- constant uint64_t & nb02,
2352- constant uint64_t & nb03,
2353- constant int64_t & ne0,
2354- constant int64_t & ne1,
2355- constant int64_t & ne2,
2356- constant int64_t & ne3,
2357- constant uint64_t & nb0,
2358- constant uint64_t & nb1,
2359- constant uint64_t & nb2,
2360- constant uint64_t & nb3,
2361- constant int & n_past,
2362- constant int & n_dims,
2363- constant int & n_ctx_orig,
2364- constant float & freq_base,
2365- constant float & freq_scale,
2366- constant float & ext_factor,
2367- constant float & attn_factor,
2368- constant float & beta_fast,
2369- constant float & beta_slow,
2318+ device const char * src0,
2319+ device const char * src1,
2320+ device const char * src2,
2321+ device char * dst,
2322+ constant ggml_metal_kargs_rope & args,
23702323 uint tiitg[[thread_index_in_threadgroup]],
23712324 uint3 tptg[[threads_per_threadgroup]],
23722325 uint3 tgpig[[threadgroup_position_in_grid]]) {
2373- const int64_t i3 = tgpig[2 ];
2374- const int64_t i2 = tgpig[1 ];
2375- const int64_t i1 = tgpig[0 ];
2326+ const int i3 = tgpig[2 ];
2327+ const int i2 = tgpig[1 ];
2328+ const int i1 = tgpig[0 ];
23762329
23772330 float corr_dims[2 ];
2378- rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
2331+ rope_yarn_corr_dims (args. n_dims , args. n_ctx_orig , args. freq_base , args. beta_fast , args. beta_slow , corr_dims);
23792332
2380- device const int32_t * pos = src1;
2333+ device const int32_t * pos = (device const int32_t *) src1;
23812334
23822335 const float theta_base = (float ) pos[i2];
2383- const float inv_ndims = -1 .f /n_dims;
2336+ const float inv_ndims = -1 .f /args. n_dims ;
23842337
23852338 float cos_theta;
23862339 float sin_theta;
23872340
2388- for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
2389- if (i0 < n_dims) {
2390- const int64_t ic = i0/2 ;
2341+ for (int i0 = 2 *tiitg; i0 < args. ne0 ; i0 += 2 *tptg.x ) {
2342+ if (i0 < args. n_dims ) {
2343+ const int ic = i0/2 ;
23912344
2392- const float theta = theta_base * pow (freq_base, inv_ndims*i0);
2345+ const float theta = theta_base * pow (args. freq_base , inv_ndims*i0);
23932346
2394- const float freq_factor = src2 != src0 ? src2[ic] : 1 .0f ;
2347+ const float freq_factor = src2 != src0 ? ((device const float *) src2) [ic] : 1 .0f ;
23952348
2396- rope_yarn (theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
2349+ rope_yarn (theta/freq_factor, args. freq_scale , corr_dims, i0, args. ext_factor , args. attn_factor , &cos_theta, &sin_theta);
23972350
2398- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
2399- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
2351+ device const T * const src = (device T *)(src0 + i3*args. nb03 + i2*args. nb02 + i1*args. nb01 + ic*args. nb00 );
2352+ device T * dst_data = (device T *)( dst + i3*args. nb3 + i2*args. nb2 + i1*args. nb1 + ic*args. nb0 );
24002353
24012354 const float x0 = src[0 ];
2402- const float x1 = src[n_dims/2 ];
2355+ const float x1 = src[args. n_dims /2 ];
24032356
2404- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
2405- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
2357+ dst_data[0 ] = x0*cos_theta - x1*sin_theta;
2358+ dst_data[args. n_dims /2 ] = x0*sin_theta + x1*cos_theta;
24062359 } else {
2407- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
2408- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2360+ device const T * const src = (device T *)(src0 + i3*args. nb03 + i2*args. nb02 + i1*args. nb01 + i0*args. nb00 );
2361+ device T * dst_data = (device T *)( dst + i3*args. nb3 + i2*args. nb2 + i1*args. nb1 + i0*args. nb0 );
24092362
24102363 dst_data[0 ] = src[0 ];
24112364 dst_data[1 ] = src[1 ];
0 commit comments