@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
67936793 }
67946794}
67956795
6796+ // ggml_compute_forward_roll
6797+
6798+ static int64_t ggml_wrap_index (int64_t i, int64_t ne) {
6799+ if (i < 0 ) {
6800+ return i + ne;
6801+ } else if (i >= ne) {
6802+ return i - ne;
6803+ }
6804+ return i;
6805+ }
6806+
6807+ static void ggml_compute_forward_roll_f32 (
6808+ const ggml_compute_params * params,
6809+ ggml_tensor * dst) {
6810+
6811+ const ggml_tensor * src0 = dst->src [0 ];
6812+ const float * src_data = (const float *) src0->data ;
6813+ float * dst_data = (float *) dst->data ;
6814+
6815+ GGML_TENSOR_UNARY_OP_LOCALS
6816+
6817+ const int s0 = ggml_get_op_params_i32 (dst, 0 );
6818+ const int s1 = ggml_get_op_params_i32 (dst, 1 );
6819+ const int s2 = ggml_get_op_params_i32 (dst, 2 );
6820+ const int s3 = ggml_get_op_params_i32 (dst, 3 );
6821+
6822+ const int64_t total = ne1 * ne2 * ne3;
6823+ const int64_t per_thread = (total + params->nth ) / params->nth ;
6824+ const int64_t start = params->ith * per_thread;
6825+ const int64_t end = std::min (start + per_thread, total);
6826+
6827+ for (int64_t i = start; i < end; ++i) {
6828+ const int64_t i1 = i % ne1;
6829+ const int64_t i2 = (i / ne1) % ne2;
6830+ const int64_t i3 = i / (ne2 * ne1);
6831+ float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof (float );
6832+
6833+ const int64_t i01 = ggml_wrap_index (i1 - s1, ne01);
6834+ const int64_t i02 = ggml_wrap_index (i2 - s2, ne02);
6835+ const int64_t i03 = ggml_wrap_index (i3 - s3, ne03);
6836+ const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof (float );
6837+
6838+ const int64_t s = ggml_wrap_index (-s0, ne00);
6839+ const int64_t n = ne00 - s;
6840+ ggml_vec_cpy_f32 (n, dst_row, src_row + s);
6841+ ggml_vec_cpy_f32 (s, dst_row + n, src_row);
6842+ }
6843+ }
6844+
6845+ void ggml_compute_forward_roll (
6846+ const ggml_compute_params * params,
6847+ ggml_tensor * dst) {
6848+
6849+ const ggml_tensor * src0 = dst->src [0 ];
6850+
6851+ switch (src0->type ) {
6852+ case GGML_TYPE_F32:
6853+ {
6854+ ggml_compute_forward_roll_f32 (params, dst);
6855+ } break ;
6856+ default :
6857+ {
6858+ GGML_ABORT (" fatal error" );
6859+ }
6860+ }
6861+ }
6862+
67966863// ggml_compute_forward_arange
67976864
67986865static void ggml_compute_forward_arange_f32 (
0 commit comments