@@ -7648,6 +7648,73 @@ void ggml_compute_forward_pad_reflect_1d(
76487648 }
76497649}
76507650
7651+ // ggml_compute_forward_roll
7652+
7653+ static int64_t ggml_wrap_index (int64_t i, int64_t ne) {
7654+ if (i < 0 ) {
7655+ return i + ne;
7656+ } else if (i >= ne) {
7657+ return i - ne;
7658+ }
7659+ return i;
7660+ }
7661+
7662+ static void ggml_compute_forward_roll_f32 (
7663+ const ggml_compute_params * params,
7664+ ggml_tensor * dst) {
7665+
7666+ const ggml_tensor * src0 = dst->src [0 ];
7667+ const float * src_data = (const float *) src0->data ;
7668+ float * dst_data = (float *) dst->data ;
7669+
7670+ GGML_TENSOR_UNARY_OP_LOCALS
7671+
7672+ const int s0 = ggml_get_op_params_i32 (dst, 0 );
7673+ const int s1 = ggml_get_op_params_i32 (dst, 1 );
7674+ const int s2 = ggml_get_op_params_i32 (dst, 2 );
7675+ const int s3 = ggml_get_op_params_i32 (dst, 3 );
7676+
7677+ const int64_t total = ne1 * ne2 * ne3;
7678+ const int64_t per_thread = (total + params->nth ) / params->nth ;
7679+ const int64_t start = params->ith * per_thread;
7680+ const int64_t end = std::min (start + per_thread, total);
7681+
7682+ for (int64_t i = start; i < end; ++i) {
7683+ const int64_t i1 = i % ne1;
7684+ const int64_t i2 = (i / ne1) % ne2;
7685+ const int64_t i3 = i / (ne2 * ne1);
7686+ float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof (float );
7687+
7688+ const int64_t i01 = ggml_wrap_index (i1 - s1, ne01);
7689+ const int64_t i02 = ggml_wrap_index (i2 - s2, ne02);
7690+ const int64_t i03 = ggml_wrap_index (i3 - s3, ne03);
7691+ const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof (float );
7692+
7693+ const int64_t s = ggml_wrap_index (-s0, ne00);
7694+ const int64_t n = ne00 - s;
7695+ ggml_vec_cpy_f32 (n, dst_row, src_row + s);
7696+ ggml_vec_cpy_f32 (s, dst_row + n, src_row);
7697+ }
7698+ }
7699+
7700+ void ggml_compute_forward_roll (
7701+ const ggml_compute_params * params,
7702+ ggml_tensor * dst) {
7703+
7704+ const ggml_tensor * src0 = dst->src [0 ];
7705+
7706+ switch (src0->type ) {
7707+ case GGML_TYPE_F32:
7708+ {
7709+ ggml_compute_forward_roll_f32 (params, dst);
7710+ } break ;
7711+ default :
7712+ {
7713+ GGML_ABORT (" fatal error" );
7714+ }
7715+ }
7716+ }
7717+
76517718// ggml_compute_forward_arange
76527719
76537720static void ggml_compute_forward_arange_f32 (
0 commit comments