@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
696696 if (ggml_is_contiguous (dst)) {
697697 // TODO: simplify
698698 if (nb00 == sizeof (float )) {
699- if (dst->type == GGML_TYPE_F32) {
700- size_t id = 0 ;
701- const size_t rs = ne00 * nb00;
702- char * dst_ptr = (char *) dst->data ;
703-
704- for (int i03 = 0 ; i03 < ne03; i03++) {
705- for (int i02 = 0 ; i02 < ne02; i02++) {
706- id += rs * ir0;
707- for (int i01 = ir0; i01 < ir1; i01++) {
708- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
709- memcpy (dst_ptr + id, src0_ptr, rs);
710- id += rs;
711- }
712- id += rs * (ne01 - ir1);
713- }
714- }
715- } else if (ggml_get_type_traits_cpu (dst->type )->from_float ) {
716- ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu (dst->type )->from_float ;
699+ if (ggml_get_type_traits_cpu (dst->type )->from_float ) {
700+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu (dst->type )->from_float ;
717701
718702 size_t id = 0 ;
719703 size_t rs = nb0 * (ne00 / ggml_blck_size (dst->type ));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
724708 id += rs * ir0;
725709 for (int i01 = ir0; i01 < ir1; i01++) {
726710 const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
727- quantize_row_q (src0_ptr, dst_ptr + id, ne00);
711+ from_float (src0_ptr, dst_ptr + id, ne00);
728712 id += rs;
729713 }
730714 id += rs * (ne01 - ir1);
@@ -2300,6 +2284,12 @@ void ggml_compute_forward_repeat(
23002284 {
23012285 ggml_compute_forward_repeat_f32 (params, dst);
23022286 } break ;
2287+ // TODO: templateify the implementation and support for I64
2288+ // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
2289+ // case GGML_TYPE_I64:
2290+ // {
2291+ // ggml_compute_forward_repeat_i64(params, dst);
2292+ // } break;
23032293 default :
23042294 {
23052295 GGML_ABORT (" fatal error" );
@@ -4470,6 +4460,74 @@ void ggml_compute_forward_get_rows(
44704460 // }
44714461}
44724462
// Forward impl of SET_ROWS for F32 source: scatters each F32 row of src0 into
// the dst row selected by an int64 index read from src1, converting the row to
// dst->type via that type's from_float trait. Rows are split across threads.
4463+ static void ggml_compute_forward_set_rows_f32 (
4464+ const ggml_compute_params * params,
4465+ ggml_tensor * dst) {
4466+
4467+ const ggml_tensor * src0 = dst->src [0 ];
4468+ const ggml_tensor * src1 = dst->src [1 ];
4469+
4470+ GGML_TENSOR_BINARY_OP_LOCALS
4471+
// nc = elements per row, nr = number of source rows to scatter
4472+ const int64_t nc = ne00;
4473+ const int64_t nr = ne01;
4474+
4475+ assert (ne0 == nc);
4476+ assert (ne2 == ne02);
4477+ assert (ne3 == ne03);
4478+ assert (src0->type == GGML_TYPE_F32);
// src1's dims 1/2 must evenly divide src0's dims 2/3 so the modulo
// broadcast of the index tensor below is well-defined
4479+ assert (ne02 % ne11 == 0 );
4480+ assert (ne03 % ne12 == 0 );
4481+
4482+ const int ith = params->ith ;
4483+ const int nth = params->nth ;
4484+
4485+ // rows per thread
4486+ const int64_t dr = (nr + nth - 1 )/nth;
4487+
4488+ // row range for this thread
4489+ const int64_t ir0 = dr*ith;
4490+ const int64_t ir1 = std::min (ir0 + dr, nr);
4491+
// NOTE(review): from_float is dereferenced below without a null check; this
// assumes every dst->type reachable here provides a from_float converter —
// confirm the graph-level SET_ROWS validation guarantees this.
4492+ ggml_from_float_t const from_float = ggml_get_type_traits_cpu (dst->type )->from_float ;
4493+
4494+ for (int64_t i03 = 0 ; i03 < ne03; ++i03) {
4495+ for (int64_t i02 = 0 ; i02 < ne02; ++i02) {
4496+ for (int64_t i = ir0; i < ir1; ++i) {
// broadcast src1 over dims 2/3 (modulo), one index per row i
4497+ const int64_t i12 = i03%ne12;
4498+ const int64_t i11 = i02%ne11;
4499+ const int64_t i10 = i;
4500+
// destination row index, read as int64 from src1
4501+ const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4502+
// hard bounds check: an out-of-range index would corrupt dst
4503+ GGML_ASSERT (i1 >= 0 && i1 < ne1);
4504+
// convert one F32 row of src0 into dst row i1 in dst->type
4505+ from_float (
4506+ (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
4507+ ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
4508+ }
4509+ }
4510+ }
4511+ }
4512+
// Dispatcher for the SET_ROWS op: routes on src0's element type.
// Only F32 sources are implemented; anything else aborts.
4513+ void ggml_compute_forward_set_rows (
4514+ const ggml_compute_params * params,
4515+ ggml_tensor * dst) {
4516+
4517+ const ggml_tensor * src0 = dst->src [0 ];
4518+
4519+ switch (src0->type ) {
4520+ case GGML_TYPE_F32:
4521+ {
4522+ ggml_compute_forward_set_rows_f32 (params, dst);
4523+ } break ;
4524+ default :
4525+ {
// unsupported source type is a programming error, not a recoverable one
4526+ GGML_ABORT (" src0->type = %d (%s) not supported" , src0->type , ggml_type_name (src0->type ));
4527+ }
4528+ }
4529+ }
4530+
44734531// ggml_compute_forward_get_rows_back
44744532
44754533static void ggml_compute_forward_get_rows_back_f32_f16 (
0 commit comments