@@ -8830,12 +8830,11 @@ static void rope(
88308830 dst[i + 1] = x0*sin_theta + x1*cos_theta;
88318831}
88328832
8833- template<typename T, bool has_pos>
8833+ template<typename T, bool has_pos, bool has_freq_facs >
88348834static void rope_neox(
88358835 const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
8836- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
8837- ,
8838- const sycl::nd_item<3> &item_ct1) {
8836+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
8837+ const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
88398838 const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
88408839 item_ct1.get_local_id(1));
88418840
@@ -8863,8 +8862,10 @@ static void rope_neox(
88638862 float cur_rot = inv_ndims * ic - ib;
88648863
88658864 const int p = has_pos ? pos[i2] : 0;
8865+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
8866+
88668867 const float theta_base =
8867- p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
8868+ p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor ;
88688869
88698870 float cos_theta, sin_theta;
88708871 rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -12413,7 +12414,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
1241312414 const int32_t *pos, float freq_scale,
1241412415 int p_delta_rows, float freq_base, float ext_factor,
1241512416 float attn_factor, rope_corr_dims corr_dims,
12416- dpct::queue_ptr stream) {
12417+ const float * freq_factors, dpct::queue_ptr stream) {
1241712418 GGML_ASSERT(ncols % 2 == 0);
1241812419 const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
1241912420 const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,38 +12424,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
1242312424 const float inv_ndims = -1.0f / n_dims;
1242412425
1242512426 if (pos == nullptr) {
12426- /*
12427- DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
12428- the limit. To get the device limit, query
12429- info::device::max_work_group_size. Adjust the work-group size if needed.
12430- */
1243112427 dpct::has_capability_or_fail(stream->get_device(),
1243212428 {sycl::aspect::fp16});
12433-
12434- stream->parallel_for(
12435- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12436- [=](sycl::nd_item<3> item_ct1) {
12437- rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
12438- p_delta_rows, ext_factor, attn_factor,
12439- corr_dims, theta_scale, inv_ndims,
12440- item_ct1);
12441- });
12429+ if (freq_factors == nullptr) {
12430+ stream->parallel_for(
12431+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12432+ [=](sycl::nd_item<3> item_ct1) {
12433+ rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
12434+ p_delta_rows, ext_factor, attn_factor,
12435+ corr_dims, theta_scale, inv_ndims, freq_factors,
12436+ item_ct1);
12437+ });
12438+ } else {
12439+ stream->parallel_for(
12440+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12441+ [=](sycl::nd_item<3> item_ct1) {
12442+ rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
12443+ p_delta_rows, ext_factor, attn_factor,
12444+ corr_dims, theta_scale, inv_ndims, freq_factors,
12445+ item_ct1);
12446+ });
12447+ }
1244212448 } else {
12443- /*
12444- DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
12445- the limit. To get the device limit, query
12446- info::device::max_work_group_size. Adjust the work-group size if needed.
12447- */
1244812449 dpct::has_capability_or_fail(stream->get_device(),
1244912450 {sycl::aspect::fp16});
1245012451
12451- stream->parallel_for(
12452- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12453- [=](sycl::nd_item<3> item_ct1) {
12454- rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
12455- p_delta_rows, ext_factor, attn_factor,
12456- corr_dims, theta_scale, inv_ndims, item_ct1);
12457- });
12452+ if (freq_factors == nullptr) {
12453+ stream->parallel_for(
12454+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12455+ [=](sycl::nd_item<3> item_ct1) {
12456+ rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12457+ p_delta_rows, ext_factor, attn_factor,
12458+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12459+ });
12460+ } else {
12461+ stream->parallel_for(
12462+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12463+ [=](sycl::nd_item<3> item_ct1) {
12464+ rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12465+ p_delta_rows, ext_factor, attn_factor,
12466+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12467+ });
12468+ }
1245812469 }
1245912470}
1246012471
@@ -13986,9 +13997,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1398613997 ggml_tensor *dst, const float *src0_dd,
1398713998 const float *src1_dd, float *dst_dd,
1398813999 const dpct::queue_ptr &main_stream) {
13989- #pragma message("TODO: implement phi3 frequency factors support")
13990- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
13991- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
14000+ const ggml_tensor * src2 = dst->src[2];
1399214001
1399314002 GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
1399414003 GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14014,6 +14023,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1401414023 memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1401514024 memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1401614025
14026+ const float * freq_factors = nullptr;
1401714027 const int32_t * pos = nullptr;
1401814028 if ((mode & 1) == 0) {
1401914029 GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14024,6 +14034,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1402414034 const bool is_neox = mode & 2;
1402514035 const bool is_glm = mode & 4;
1402614036
14037+ if (is_neox) {
14038+ pos = (const int32_t *) src1_dd;
14039+
14040+ if (src2 != nullptr) {
14041+ freq_factors = (const float *) src2->data;
14042+ }
14043+ } else {
14044+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
14045+ }
14046+
1402714047 rope_corr_dims corr_dims;
1402814048 ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
1402914049
@@ -14035,13 +14055,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1403514055 if (src0->type == GGML_TYPE_F32) {
1403614056 rope_neox_sycl(
1403714057 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
14038- attn_factor, corr_dims, main_stream
14058+ attn_factor, corr_dims, freq_factors, main_stream
1403914059 );
1404014060 } else if (src0->type == GGML_TYPE_F16) {
1404114061 rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
1404214062 ne00, n_dims, nrows, pos, freq_scale, ne01,
1404314063 freq_base, ext_factor, attn_factor, corr_dims,
14044- main_stream);
14064+ freq_factors, main_stream);
1404514065 } else {
1404614066 GGML_ASSERT(false);
1404714067 }
0 commit comments