@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
776776 id += ne00 * (ne01 - ir1);
777777 }
778778 }
779+ } else if (dst->type == GGML_TYPE_I32) {
780+ size_t id = 0 ;
781+ int32_t * dst_ptr = (int32_t *) dst->data ;
782+
783+ for (int i03 = 0 ; i03 < ne03; i03++) {
784+ for (int i02 = 0 ; i02 < ne02; i02++) {
785+ id += ne00 * ir0;
786+ for (int i01 = ir0; i01 < ir1; i01++) {
787+ for (int i00 = 0 ; i00 < ne00; i00++) {
788+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
789+
790+ dst_ptr[id] = *src0_ptr;
791+ id++;
792+ }
793+ }
794+ id += ne00 * (ne01 - ir1);
795+ }
796+ }
779797 } else {
780798 GGML_ABORT (" fatal error" ); // TODO: implement
781799 }
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
947965 }
948966 }
949967 }
968+ } else if (dst->type == GGML_TYPE_I32) {
969+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
970+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
971+ i10 += ne00 * ir0;
972+ while (i10 >= ne0) {
973+ i10 -= ne0;
974+ if (++i11 == ne1) {
975+ i11 = 0 ;
976+ if (++i12 == ne2) {
977+ i12 = 0 ;
978+ if (++i13 == ne3) {
979+ i13 = 0 ;
980+ }
981+ }
982+ }
983+ }
984+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
985+ for (int64_t i00 = 0 ; i00 < ne00; i00++) {
986+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
987+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
988+
989+ *(int32_t *) dst_ptr = *(const float *) src0_ptr;
990+
991+ if (++i10 == ne0) {
992+ i10 = 0 ;
993+ if (++i11 == ne1) {
994+ i11 = 0 ;
995+ if (++i12 == ne2) {
996+ i12 = 0 ;
997+ if (++i13 == ne3) {
998+ i13 = 0 ;
999+ }
1000+ }
1001+ }
1002+ }
1003+ }
1004+ }
1005+ i10 += ne00 * (ne01 - ir1);
1006+ while (i10 >= ne0) {
1007+ i10 -= ne0;
1008+ if (++i11 == ne1) {
1009+ i11 = 0 ;
1010+ if (++i12 == ne2) {
1011+ i12 = 0 ;
1012+ if (++i13 == ne3) {
1013+ i13 = 0 ;
1014+ }
1015+ }
1016+ }
1017+ }
1018+ }
1019+ }
1020+ } else {
1021+ GGML_ABORT (" fatal error" ); // TODO: implement
1022+ }
1023+ }
1024+
1025+ static void ggml_compute_forward_dup_i32 (
1026+ const ggml_compute_params * params,
1027+ ggml_tensor * dst) {
1028+
1029+ const ggml_tensor * src0 = dst->src [0 ];
1030+
1031+ GGML_ASSERT (ggml_nelements (dst) == ggml_nelements (src0));
1032+
1033+ GGML_TENSOR_UNARY_OP_LOCALS
1034+
1035+ const int ith = params->ith ; // thread index
1036+ const int nth = params->nth ; // number of threads
1037+
1038+ // parallelize by rows
1039+ const int nr = ne01;
1040+ // number of rows per thread
1041+ const int dr = (nr + nth - 1 ) / nth;
1042+ // row range for this thread
1043+ const int ir0 = dr * ith;
1044+ const int ir1 = MIN (ir0 + dr, nr);
1045+
1046+ // dst counters
1047+
1048+ int64_t i10 = 0 ;
1049+ int64_t i11 = 0 ;
1050+ int64_t i12 = 0 ;
1051+ int64_t i13 = 0 ;
1052+
1053+ // TODO: not optimal, but works
1054+ if (dst->type == GGML_TYPE_F32) {
1055+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
1056+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
1057+ i10 += ne00 * ir0;
1058+ while (i10 >= ne0) {
1059+ i10 -= ne0;
1060+ if (++i11 == ne1) {
1061+ i11 = 0 ;
1062+ if (++i12 == ne2) {
1063+ i12 = 0 ;
1064+ if (++i13 == ne3) {
1065+ i13 = 0 ;
1066+ }
1067+ }
1068+ }
1069+ }
1070+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
1071+ for (int64_t i00 = 0 ; i00 < ne00; i00++) {
1072+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
1073+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
1074+
1075+ *(float *) dst_ptr = *(const int32_t *) src0_ptr;
1076+
1077+ if (++i10 == ne0) {
1078+ i10 = 0 ;
1079+ if (++i11 == ne1) {
1080+ i11 = 0 ;
1081+ if (++i12 == ne2) {
1082+ i12 = 0 ;
1083+ if (++i13 == ne3) {
1084+ i13 = 0 ;
1085+ }
1086+ }
1087+ }
1088+ }
1089+ }
1090+ }
1091+ i10 += ne00 * (ne01 - ir1);
1092+ while (i10 >= ne0) {
1093+ i10 -= ne0;
1094+ if (++i11 == ne1) {
1095+ i11 = 0 ;
1096+ if (++i12 == ne2) {
1097+ i12 = 0 ;
1098+ if (++i13 == ne3) {
1099+ i13 = 0 ;
1100+ }
1101+ }
1102+ }
1103+ }
1104+ }
1105+ }
9501106 } else {
9511107 GGML_ABORT (" fatal error" ); // TODO: implement
9521108 }
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
11771333 {
11781334 ggml_compute_forward_dup_f32 (params, dst);
11791335 } break ;
1336+ case GGML_TYPE_I32:
1337+ {
1338+ ggml_compute_forward_dup_i32 (params, dst);
1339+ } break ;
11801340 default :
11811341 {
11821342 if (ggml_is_quantized (src0->type ) && dst->type == GGML_TYPE_F32) {
0 commit comments