@@ -2713,15 +2713,161 @@ kernel void kernel_rope_neox(
27132713 }
27142714}
27152715
// Multi-section RoPE ("mrope").
//
// For every even index i0 below n_dims, the value at byte offset (i0/2)*nb00 of the
// row and the value n_dims/2 elements after it are rotated by an angle derived from a
// position read from src1. Which of the four position planes in src1 is used depends
// on the section (sect_0 .. sect_3) that the pair index i0/2 falls into. Values at
// i0 >= n_dims are copied to dst unchanged.
//
//   args - rope parameters: sizes, byte strides, yarn settings and section widths
//   src0 - input data, addressed in bytes via nb00 .. nb03
//   src1 - int32 positions; plane k is read at index i2 + k*ne02
//   src2 - per-pair float frequency factors; a factor of 1 is used when src2 == src0
//   dst  - output data, addressed in bytes via nb0 .. nb3
//
// The threadgroup position picks the row (i1, i2, i3) = (tgpig[0], tgpig[1], tgpig[2]);
// the threads of the group walk that row in steps of 2*tptg.x values.
template <typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int row1 = tgpig[0];
    const int row2 = tgpig[1];
    const int row3 = tgpig[2];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * positions = (device const int32_t *) src1;

    const float neg_inv_ndims = -1.f/args.n_dims;

    // cumulative section boundaries; they do not depend on the loop index
    const int sect_end0  = args.sect_0;
    const int sect_end1  = sect_end0 + args.sect_1;
    const int sect_end2  = sect_end1 + args.sect_2;
    const int sect_total = sect_end2 + args.sect_3;

    // start of the row handled by this threadgroup
    device const char * src_row = src0 + row3*args.nb03 + row2*args.nb02 + row1*args.nb01;
    device       char * dst_row = dst  + row3*args.nb3  + row2*args.nb2  + row1*args.nb1;

    const int half_dims = args.n_dims/2;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // beyond the rotated range: pass the two values through as they are
            device const T * in  = (device const T *)(src_row + i0*args.nb00);
            device       T * out = (device       T *)(dst_row + i0*args.nb0);

            out[0] = in[0];
            out[1] = in[1];

            continue;
        }

        const int ic = i0/2;

        // mrope: choose the position plane from the section the pair index lands in
        const int sector = ic % sect_total;

        int plane;
        if (sector < sect_end0) {
            plane = 0;
        } else if (sector < sect_end1) {
            plane = 1;
        } else if (sector < sect_end2) {
            plane = 2;
        } else {
            plane = 3;
        }

        const float theta_base = (float) positions[row2 + args.ne02 * plane];

        // from here on the computation follows kernel_rope_neox
        const float theta = theta_base * pow(args.freq_base, neg_inv_ndims*i0);

        const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;
        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * in  = (device const T *)(src_row + ic*args.nb00);
        device       T * out = (device       T *)(dst_row + ic*args.nb0);

        // read both inputs before writing either output
        const float x0 = in[0];
        const float x1 = in[half_dims];

        out[0]         = x0*cos_theta - x1*sin_theta;
        out[half_dims] = x0*sin_theta + x1*cos_theta;
    }
}
2786+
// Two-section RoPE variant used for vision inputs.
//
// For every even index i0 below 2*n_dims, the value at byte offset (i0/2)*nb00 of the
// row and the value n_dims elements after it are rotated. The pair index i0/2 is
// reduced modulo (sect_0 + sect_1); pairs in the first section use the position plane
// at pos[i2], pairs in the second section use the plane at pos[i2 + ne02]. Within each
// section the frequency exponent restarts from zero. Values at i0 >= 2*n_dims are
// copied to dst unchanged.
//
//   args - rope parameters: sizes, byte strides, yarn settings and section widths
//   src0 - input data, addressed in bytes via nb00 .. nb03
//   src1 - int32 positions; plane k is read at index i2 + k*ne02
//   src2 - per-pair float frequency factors; a factor of 1 is used when src2 == src0
//   dst  - output data, addressed in bytes via nb0 .. nb3
//
// The threadgroup position picks the row (i1, i2, i3) = (tgpig[0], tgpig[1], tgpig[2]);
// the threads of the group walk that row in steps of 2*tptg.x values.
template <typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector    = ic % sect_dims;

            float p;
            float theta_base;
            // The first section covers [0, sect_0). The boundary must be sect_0 so that
            // it agrees with the "- sect_0" offset applied to the second section below
            // and with the first boundary used by kernel_rope_multi. The result is the
            // same as before whenever sect_0 == sect_1.
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            // keep the const qualifier of src0 instead of casting it away
            device const T * const src = (device const T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device       T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            // read both inputs before writing either output
            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            // beyond the rotated range: pass the two values through as they are
            device const T * const src = (device const T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device       T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
2853+
// Explicit instantiations of the rope kernel templates.
// Every variant is instantiated for float and half; the [[host_name]] attribute gives
// the name under which each instantiation is exported. The float instantiation of each
// template supplies the function type used for both of its instantiations.

// kernel_rope_norm
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;

template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;

// kernel_rope_neox
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;

template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

// kernel_rope_multi
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;

template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;

// kernel_rope_vision
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;

template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2870+
27252871typedef void (im2col_t )(
27262872 device const float * x,
27272873 device char * dst,
0 commit comments