@@ -58,10 +58,10 @@ static __global__ void rope(
5858 dst[i + 1 ] = x0*sin_theta + x1*cos_theta;
5959}
6060
61- template <typename T, bool has_pos>
61+ template <typename T, bool has_pos, bool has_freq_facs >
6262static __global__ void rope_neox (
6363 const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
64+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
6565) {
6666 const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
6767
@@ -88,7 +88,9 @@ static __global__ void rope_neox(
8888 float cur_rot = inv_ndims * ic - ib;
8989
9090 const int p = has_pos ? pos[i2] : 0 ;
91- const float theta_base = p*freq_scale*powf (theta_scale, col/2 .0f );
91+ const float freq_factor = has_freq_facs ? freq_factors[ic/2 ] : 1 .0f ;
92+
93+ const float theta_base = p*freq_scale*powf (theta_scale, col/2 .0f )/freq_factor;
9294
9395 float cos_theta, sin_theta;
9496 rope_yarn (theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
164166template <typename T>
165167static void rope_neox_cuda (
166168 const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
167- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
169+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
168170) {
169171 GGML_ASSERT (ncols % 2 == 0 );
170172 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
@@ -175,15 +177,29 @@ static void rope_neox_cuda(
175177 const float inv_ndims = -1 .0f / n_dims;
176178
177179 if (pos == nullptr ) {
178- rope_neox<T, false ><<<block_nums, block_dims, 0 , stream>>> (
179- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
180- theta_scale, inv_ndims
181- );
180+ if (freq_factors == nullptr ) {
181+ rope_neox<T, false , false ><<<block_nums, block_dims, 0 , stream>>> (
182+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
183+ theta_scale, inv_ndims, freq_factors
184+ );
185+ } else {
186+ rope_neox<T, false , true ><<<block_nums, block_dims, 0 , stream>>> (
187+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
188+ theta_scale, inv_ndims, freq_factors
189+ );
190+ }
182191 } else {
183- rope_neox<T, true ><<<block_nums, block_dims, 0 , stream>>> (
184- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
185- theta_scale, inv_ndims
186- );
192+ if (freq_factors == nullptr ) {
193+ rope_neox<T, true , false ><<<block_nums, block_dims, 0 , stream>>> (
194+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
195+ theta_scale, inv_ndims, freq_factors
196+ );
197+ } else {
198+ rope_neox<T, true , true ><<<block_nums, block_dims, 0 , stream>>> (
199+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
200+ theta_scale, inv_ndims, freq_factors
201+ );
202+ }
187203 }
188204}
189205
@@ -214,24 +230,27 @@ static void rope_cuda_f32(
214230
215231static void rope_neox_cuda_f16 (
216232 const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
217- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
233+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
218234
219- rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
235+ rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
220236}
221237
222238static void rope_neox_cuda_f32 (
223239 const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
224- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
240+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
225241) {
226242
227- rope_neox_cuda<float >(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
243+ rope_neox_cuda<float >(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
228244}
229245
230246void ggml_cuda_op_rope (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
231247 const ggml_tensor * src0 = dst->src [0 ];
232248 const ggml_tensor * src1 = dst->src [1 ];
249+ const ggml_tensor * src2 = dst->src [2 ];
250+
233251 const float * src0_d = (const float *)src0->data ;
234252 const float * src1_d = (const float *)src1->data ;
253+
235254 float * dst_d = (float *)dst->data ;
236255 cudaStream_t stream = ctx.stream ();
237256
@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
241260
242261 const int64_t ne00 = src0->ne [0 ];
243262 const int64_t ne01 = src0->ne [1 ];
244- const int64_t ne2 = dst->ne [2 ];
245263 const int64_t nrows = ggml_nrows (src0);
246264
247265 // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
259277 memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
260278 memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
261279
280+ const float * freq_factors = nullptr ;
262281 const int32_t * pos = nullptr ;
263- if ((mode & 1 ) == 0 ) {
264- GGML_ASSERT (src1->type == GGML_TYPE_I32);
265- GGML_ASSERT (src1->ne [0 ] == ne2);
266- pos = (const int32_t *) src1_d;
267- }
268282
269283 const bool is_neox = mode & 2 ;
270284 const bool is_glm = mode & 4 ;
271285
286+ // positions are consumed by every rope variant (normal, glm, neox); only freq_factors is neox-specific
287+ pos = (const int32_t *) src1_d;
288+ if (is_neox) {
289+ if (src2 != nullptr ) {
290+ freq_factors = (const float *) src2->data ;
291+ }
292+ } else {
293+ GGML_ASSERT (src2 == nullptr && " TODO: freq_factors not implemented for !is_neox" );
294+ }
295+
272296 rope_corr_dims corr_dims;
273297 ggml_rope_yarn_corr_dims (n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v );
274298
@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
280304 if (src0->type == GGML_TYPE_F32) {
281305 rope_neox_cuda_f32 (
282306 (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
283- attn_factor, corr_dims, stream
307+ attn_factor, corr_dims, freq_factors, stream
284308 );
285309 } else if (src0->type == GGML_TYPE_F16) {
286310 rope_neox_cuda_f16 (
287311 (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
288- attn_factor, corr_dims, stream
312+ attn_factor, corr_dims, freq_factors, stream
289313 );
290314 } else {
291315 GGML_ASSERT (false );
0 commit comments