@@ -820,44 +820,62 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
820820 // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
821821 // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
822822 memset (qweights, 0 , m * k / g / ngroups_per_elem);
823+
824+ int c0_fac2 = k / g;
825+ int c0_fac1 = simd_n_out * c0_fac2;
826+ int c0_fac0 = bits * c0_fac1;
827+
828+ int c1_nb2 = k / g;
829+ int c1_nb1 = simd_n_in * c1_nb2;
830+ int c1_nb0 = ngroups_per_elem * c1_nb1;
831+ int c1_fac2 = k / g;
832+ int c1_fac1 = ngroups_per_elem * c1_fac2;
833+ int c1_fac0 = simd_n_in * c1_fac1;
834+
835+
836+ int c2_nb4 = kfactor;
837+ int c2_nb3 = k / g / kfactor * c2_nb4;
838+ int c2_nb2 = ngroups_per_elem * c2_nb3;
839+ int c2_nb1 = simd_n_in * c2_nb2;
840+ int c2_nb0 = bm / mgroup * c2_nb1;
841+ int c2_fac3 = simd_n_in * ngroups_per_elem;
842+ int c2_fac2 = kfactor * c2_fac3;
843+ int c2_fac1 = bm / mgroup * c2_fac2;
844+ int c2_fac0 = k / g / kfactor * c2_fac1;
845+
823846 for (int im = 0 ; im < m / bits; im++) {
824847 for (int ib = 0 ; ib < bits; ib++) {
825848 for (int ik = 0 ; ik < k / g; ik++) {
849+ // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
826850 int new_im = im / simd_n_out;
827851 int new_isno = im % simd_n_out;
828852 int new_ib = ib;
829853 int new_ik = ik;
830- // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
831- int new_idx = new_im * bits * simd_n_out * k / g + new_ib * simd_n_out * k / g + new_isno * k / g + new_ik;
854+ int new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
855+
832856 // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
833- int nb2 = k / g;
834- int nb1 = simd_n_in * nb2;
835- int nb0 = ngroups_per_elem * nb1;
836- new_im = new_idx / nb0;
837- int new_ing = (new_idx % nb0) / nb1;
838- int new_isni = (new_idx % nb1) / nb2;
839- new_ik = (new_idx % nb2);
840- new_idx = new_im * ngroups_per_elem * simd_n_in * k / g + new_isni * ngroups_per_elem * k / g + new_ing * k / g + new_ik;
857+ new_im = new_idx / c1_nb0;
858+ int new_ing = (new_idx % c1_nb0) / c1_nb1;
859+ int new_isni = (new_idx % c1_nb1) / c1_nb2;
860+ new_ik = (new_idx % c1_nb2);
861+ new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
862+
841863 // # 0 1 2 3 4 5
842864 // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
843- int nb4 = kfactor;
844- int nb3 = k / g / kfactor * nb4;
845- nb2 = ngroups_per_elem * nb3;
846- nb1 = simd_n_in * nb2;
847- nb0 = bm / mgroup * nb1;
848- new_im = new_idx / nb0;
849- int new_ibm = (new_idx % nb0) / nb1;
850- new_isni = (new_idx % nb1) / nb2;
851- new_ing = (new_idx % nb2) / nb3;
852- new_ik = (new_idx % nb3) / nb4;
853- int new_ikf = (new_idx % nb4);
854- new_idx = new_im * k / g / kfactor * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
855- new_ik * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
856- new_ibm * kfactor * simd_n_in * ngroups_per_elem +
857- new_ikf * simd_n_in * ngroups_per_elem +
865+ new_im = new_idx / c2_nb0;
866+ int new_ibm = (new_idx % c2_nb0) / c2_nb1;
867+ new_isni = (new_idx % c2_nb1) / c2_nb2;
868+ new_ing = (new_idx % c2_nb2) / c2_nb3;
869+ new_ik = (new_idx % c2_nb3) / c2_nb4;
870+ int new_ikf = (new_idx % c2_nb4);
871+ new_idx = new_im * c2_fac0 +
872+ new_ik * c2_fac1 +
873+ new_ibm * c2_fac2 +
874+ new_ikf * c2_fac3 +
858875 new_isni * ngroups_per_elem +
859876 new_ing;
860877 new_idx = new_idx / ngroups_per_elem;
878+
861879 // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
862880 qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g);
863881 }
0 commit comments