Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
ggml_sycl_op_diag_mask_inf(ctx, dst);
}

static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_rope(ctx, dst);
}

static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_pool2d(ctx, dst);
}
Expand Down Expand Up @@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (mode == GGML_ROPE_TYPE_MROPE) {
return false;
}
return ggml_is_contiguous(op->src[0]);
return true;
}
case GGML_OP_IM2COL:
return true;
Expand Down
197 changes: 94 additions & 103 deletions ggml/src/ggml-sycl/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,90 +34,92 @@ static void rope_yarn(
*sin_theta = sycl::sin(theta) * mscale;
}

template<typename T, bool has_ff>
static void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
template <typename T, bool has_ff>
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));

if (i0 >= ne0) {
return;
}

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row*ne0 + i0;
const int i = row * ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;
const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;

const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

float cos_theta;
float sin_theta;

rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + 1];
const float x0 = x[i2 + 0];
const float x1 = x[i2 + 1];

dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta;
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
}

template<typename T, bool has_ff>
static void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
template <typename T, bool has_ff>
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));

if (i0 >= ne0) {
return;
}

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row*ne0 + i0;
const int i = row * ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;

const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

float cos_theta;
float sin_theta;

rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
const float x0 = x[i2 + 0];
const float x1 = x[i2 + n_dims / 2];

dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}

template <typename T, bool has_ff>
Expand Down Expand Up @@ -163,80 +165,66 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
}

template <typename T>
static void rope_norm_sycl(
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nr);

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);

dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
} else {
/*
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
}
}

template <typename T>
static void rope_neox_sycl(
const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nr);

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);

dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
} else {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
}
}

Expand Down Expand Up @@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
}
}

void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
Expand Down Expand Up @@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
if (is_neox) {
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl(
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl(
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else {
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl(
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl(
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
} else {
GGML_ABORT("fatal error");
}
}
}

void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_rope(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

2 changes: 1 addition & 1 deletion ggml/src/ggml-sycl/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@

#include "common.hpp"

void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);

#endif // GGML_SYCL_ROPE_HPP
Loading