@@ -4,6 +4,11 @@ struct rope_corr_dims {
44 float v[2 ];
55};
66
7+
8+ struct mrope_sections {
9+ int v[4 ];
10+ };
11+
712static __device__ float rope_yarn_ramp (const float low, const float high, const int i0) {
813 const float y = (i0 / 2 - low) / max (0 .001f , high - low);
914 return 1 .0f - min (1 .0f , max (0 .0f , y));
@@ -108,6 +113,105 @@ static __global__ void rope_neox(
108113 dst[i + n_dims/2 ] = x0*sin_theta + x1*cos_theta;
109114}
110115
116+ template <typename T, bool has_ff>
117+ static __global__ void rope_multi (
118+ const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
119+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
120+ const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
121+
122+ if (i0 >= ne0) {
123+ return ;
124+ }
125+
126+ const int row = blockDim .x *blockIdx .x + threadIdx .x ;
127+
128+ if (i0 >= n_dims) {
129+ const int i = row*ne0 + i0;
130+
131+ dst[i + 0 ] = x[i + 0 ];
132+ dst[i + 1 ] = x[i + 1 ];
133+
134+ return ;
135+ }
136+
137+ const int i = row*ne0 + i0/2 ;
138+ const int i2 = row/p_delta_rows;
139+
140+ int sect_dims = sections.v [0 ] + sections.v [1 ] + sections.v [2 ] + sections.v [3 ];
141+ int sec_w = sections.v [1 ] + sections.v [0 ];
142+ int sector = (i0 / 2 ) % sect_dims;
143+
144+ float theta_base = 0.0 ;
145+ if (sector < sections.v [0 ]) {
146+ theta_base = pos[i2]*powf (theta_scale, i0/2 .0f );
147+ }
148+ else if (sector >= sections.v [0 ] && sector < sec_w) {
149+ theta_base = pos[i2 + ne2 * 1 ]*powf (theta_scale, i0/2 .0f );
150+ }
151+ else if (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
152+ theta_base = pos[i2 + ne2 * 2 ]*powf (theta_scale, i0/2 .0f );
153+ }
154+ else if (sector >= sec_w + sections.v [2 ]) {
155+ theta_base = pos[i2 + ne2 * 3 ]*powf (theta_scale, i0/2 .0f );
156+ }
157+
158+ const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
159+
160+ float cos_theta;
161+ float sin_theta;
162+
163+ rope_yarn (theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
164+
165+ const float x0 = x[i + 0 ];
166+ const float x1 = x[i + n_dims/2 ];
167+
168+ dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
169+ dst[i + n_dims/2 ] = x0*sin_theta + x1*cos_theta;
170+ }
171+
172+ template <typename T, bool has_ff>
173+ static __global__ void rope_vision (
174+ const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
175+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
176+ const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
177+
178+ if (i0 >= ne0) {
179+ return ;
180+ }
181+
182+ const int row = blockDim .x *blockIdx .x + threadIdx .x ;
183+
184+ const int i = row*ne0 + i0/2 ;
185+ const int i2 = row/p_delta_rows; // i2-th tokens
186+
187+ int sect_dims = sections.v [0 ] + sections.v [1 ];
188+ int sec_w = sections.v [1 ] + sections.v [0 ];
189+ int sector = (i0 / 2 ) % sect_dims;
190+
191+ float theta_base = 0.0 ;
192+ if (sector < sections.v [0 ]) {
193+ const int p = sector;
194+ theta_base = pos[i2]*powf (theta_scale, p);
195+ }
196+ else if (sector >= sections.v [0 ] && sector < sec_w) {
197+ const int p = sector - sections.v [0 ];
198+ theta_base = pos[i2 + ne2]*powf (theta_scale, p);
199+ }
200+
201+ const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
202+
203+ float cos_theta;
204+ float sin_theta;
205+
206+ rope_yarn (theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
207+
208+ const float x0 = x[i + 0 ];
209+ const float x1 = x[i + n_dims];
210+
211+ dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
212+ dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
213+ }
214+
111215template <typename T>
112216static void rope_norm_cuda (
113217 const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -156,6 +260,56 @@ static void rope_neox_cuda(
156260 }
157261}
158262
263+ template <typename T>
264+ static void rope_multi_cuda (
265+ const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
266+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
267+ GGML_ASSERT (ne0 % 2 == 0 );
268+ const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
269+ const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
270+ const dim3 block_nums (nr, n_blocks_x, 1 );
271+
272+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
273+
274+ if (freq_factors == nullptr ) {
275+ rope_multi<T, false ><<<block_nums, block_dims, 0 , stream>>> (
276+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
277+ theta_scale, freq_factors, sections
278+ );
279+ } else {
280+ rope_multi<T, true ><<<block_nums, block_dims, 0 , stream>>> (
281+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
282+ theta_scale, freq_factors, sections
283+ );
284+ }
285+ }
286+
287+ template <typename T>
288+ static void rope_vision_cuda (
289+ const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
290+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
291+ GGML_ASSERT (ne0 % 2 == 0 );
292+ const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
293+ const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
294+ const dim3 block_nums (nr, n_blocks_x, 1 );
295+ // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
296+ // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
297+
298+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
299+
300+ if (freq_factors == nullptr ) {
301+ rope_vision<T, false ><<<block_nums, block_dims, 0 , stream>>> (
302+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
303+ theta_scale, freq_factors, sections
304+ );
305+ } else {
306+ rope_vision<T, true ><<<block_nums, block_dims, 0 , stream>>> (
307+ x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
308+ theta_scale, freq_factors, sections
309+ );
310+ }
311+ }
312+
159313static void rope_norm_cuda_f16 (
160314 const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
161315 float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -185,6 +339,38 @@ static void rope_neox_cuda_f32(
185339 rope_neox_cuda<float >(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
186340}
187341
342+ static void rope_multi_cuda_f16 (
343+ const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
344+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
345+ ) {
346+
347+ rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
348+ }
349+
350+ static void rope_multi_cuda_f32 (
351+ const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
352+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
353+ ) {
354+
355+ rope_multi_cuda<float >(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
356+ }
357+
358+ static void rope_vision_cuda_f16 (
359+ const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
360+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
361+ ) {
362+
363+ rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
364+ }
365+
366+ static void rope_vision_cuda_f32 (
367+ const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
368+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
369+ ) {
370+
371+ rope_vision_cuda<float >(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
372+ }
373+
188374void ggml_cuda_op_rope (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
189375 const ggml_tensor * src0 = dst->src [0 ];
190376 const ggml_tensor * src1 = dst->src [1 ];
@@ -201,15 +387,17 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
201387 GGML_ASSERT ( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
202388 GGML_ASSERT (src0->type == dst->type );
203389
204- const int64_t ne00 = src0->ne [0 ];
205- const int64_t ne01 = src0->ne [1 ];
390+ const int64_t ne00 = src0->ne [0 ]; // head dims
391+ const int64_t ne01 = src0->ne [1 ]; // num heads
392+ const int64_t ne02 = src0->ne [2 ]; // num heads
206393 const int64_t nr = ggml_nrows (src0);
207394
208395 // const int n_past = ((int32_t *) dst->op_params)[0];
209396 const int n_dims = ((int32_t *) dst->op_params )[1 ];
210397 const int mode = ((int32_t *) dst->op_params )[2 ];
211398 // const int n_ctx = ((int32_t *) dst->op_params)[3];
212399 const int n_ctx_orig = ((int32_t *) dst->op_params )[4 ];
400+ mrope_sections sections;
213401
214402 // RoPE alteration for extended context
215403 float freq_base;
@@ -225,8 +413,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
225413 memcpy (&attn_factor, (int32_t *) dst->op_params + 8 , sizeof (float ));
226414 memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
227415 memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
416+ memcpy (§ions.v , (int32_t *) dst->op_params + 11 , sizeof (int )*4 );
228417
229418 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
419+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
420+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
421+
422+ if (is_mrope) {
423+ GGML_ASSERT (sections.v [0 ] > 0 || sections.v [1 ] > 0 || sections.v [2 ] > 0 );
424+ }
425+
426+ if (is_vision) {
427+ GGML_ASSERT (n_dims == ne00/2 );
428+ }
230429
231430 const int32_t * pos = (const int32_t *) src1_d;
232431
@@ -253,6 +452,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
253452 } else {
254453 GGML_ABORT (" fatal error" );
255454 }
455+ } else if (is_mrope && !is_vision) {
456+ if (src0->type == GGML_TYPE_F32) {
457+ rope_multi_cuda_f32 (
458+ (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
459+ attn_factor, corr_dims, freq_factors, sections, stream
460+ );
461+ } else if (src0->type == GGML_TYPE_F16) {
462+ rope_multi_cuda_f16 (
463+ (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
464+ attn_factor, corr_dims, freq_factors, sections, stream
465+ );
466+ } else {
467+ GGML_ABORT (" fatal error" );
468+ }
469+ } else if (is_vision) {
470+ if (src0->type == GGML_TYPE_F32) {
471+ rope_vision_cuda_f32 (
472+ (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
473+ attn_factor, corr_dims, freq_factors, sections, stream
474+ );
475+ } else if (src0->type == GGML_TYPE_F16) {
476+ rope_vision_cuda_f16 (
477+ (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
478+ attn_factor, corr_dims, freq_factors, sections, stream
479+ );
480+ } else {
481+ GGML_ABORT (" fatal error" );
482+ }
256483 } else {
257484 if (src0->type == GGML_TYPE_F32) {
258485 rope_norm_cuda_f32 (
0 commit comments