@@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
22822282 }
22832283}
22842284
2285+ static void ggml_compute_forward_repeat_i64 (
2286+ const ggml_compute_params * params,
2287+ ggml_tensor * dst) {
2288+
2289+ const ggml_tensor * src0 = dst->src [0 ];
2290+
2291+ if (params->ith != 0 ) {
2292+ return ;
2293+ }
2294+
2295+ GGML_ASSERT (ggml_can_repeat (src0, dst));
2296+
2297+ GGML_TENSOR_UNARY_OP_LOCALS
2298+
2299+ // guaranteed to be an integer due to the check in ggml_can_repeat
2300+ const int nr0 = (int )(ne0/ne00);
2301+ const int nr1 = (int )(ne1/ne01);
2302+ const int nr2 = (int )(ne2/ne02);
2303+ const int nr3 = (int )(ne3/ne03);
2304+
2305+ // TODO: support for transposed / permuted tensors
2306+ GGML_ASSERT (nb0 == sizeof (int64_t ));
2307+ GGML_ASSERT (nb00 == sizeof (int64_t ));
2308+
2309+ // TODO: maybe this is not optimal?
2310+ for (int i3 = 0 ; i3 < nr3; i3++) {
2311+ for (int k3 = 0 ; k3 < ne03; k3++) {
2312+ for (int i2 = 0 ; i2 < nr2; i2++) {
2313+ for (int k2 = 0 ; k2 < ne02; k2++) {
2314+ for (int i1 = 0 ; i1 < nr1; i1++) {
2315+ for (int k1 = 0 ; k1 < ne01; k1++) {
2316+ for (int i0 = 0 ; i0 < nr0; i0++) {
2317+ int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
2318+ int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
2319+ for (int i = 0 ; i < ne00; ++i) {
2320+ y[i] = x[i];
2321+ }
2322+ }
2323+ }
2324+ }
2325+ }
2326+ }
2327+ }
2328+ }
2329+ }
2330+
22852331void ggml_compute_forward_repeat (
22862332 const ggml_compute_params * params,
22872333 ggml_tensor * dst) {
@@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
23002346 {
23012347 ggml_compute_forward_repeat_f32 (params, dst);
23022348 } break ;
2349+ case GGML_TYPE_I64:
2350+ {
2351+ ggml_compute_forward_repeat_i64 (params, dst);
2352+ } break ;
23032353 default :
23042354 {
23052355 GGML_ABORT (" fatal error" );
0 commit comments