Skip to content

Commit d70dcd2

Browse files
committed
Speedup tensor transformation by ~35%.
1 parent 2a5ed43 commit d70dcd2

File tree

1 file changed

+43
-25
lines changed

1 file changed

+43
-25
lines changed

ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -820,44 +820,62 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
820820
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
821821
// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
822822
memset(qweights, 0, m * k / g / ngroups_per_elem);
823+
824+
int c0_fac2 = k / g;
825+
int c0_fac1 = simd_n_out * c0_fac2;
826+
int c0_fac0 = bits * c0_fac1;
827+
828+
int c1_nb2 = k / g;
829+
int c1_nb1 = simd_n_in * c1_nb2;
830+
int c1_nb0 = ngroups_per_elem * c1_nb1;
831+
int c1_fac2 = k / g;
832+
int c1_fac1 = ngroups_per_elem * c1_fac2;
833+
int c1_fac0 = simd_n_in * c1_fac1;
834+
835+
836+
int c2_nb4 = kfactor;
837+
int c2_nb3 = k / g / kfactor * c2_nb4;
838+
int c2_nb2 = ngroups_per_elem * c2_nb3;
839+
int c2_nb1 = simd_n_in * c2_nb2;
840+
int c2_nb0 = bm / mgroup * c2_nb1;
841+
int c2_fac3 = simd_n_in * ngroups_per_elem;
842+
int c2_fac2 = kfactor * c2_fac3;
843+
int c2_fac1 = bm / mgroup * c2_fac2;
844+
int c2_fac0 = k / g / kfactor * c2_fac1;
845+
823846
for (int im = 0; im < m / bits; im++) {
824847
for (int ib = 0; ib < bits; ib++) {
825848
for (int ik = 0; ik < k / g; ik++) {
849+
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
826850
int new_im = im / simd_n_out;
827851
int new_isno = im % simd_n_out;
828852
int new_ib = ib;
829853
int new_ik = ik;
830-
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
831-
int new_idx = new_im * bits * simd_n_out * k / g + new_ib * simd_n_out * k / g + new_isno * k / g + new_ik;
854+
int new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
855+
832856
// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
833-
int nb2 = k / g;
834-
int nb1 = simd_n_in * nb2;
835-
int nb0 = ngroups_per_elem * nb1;
836-
new_im = new_idx / nb0;
837-
int new_ing = (new_idx % nb0) / nb1;
838-
int new_isni = (new_idx % nb1) / nb2;
839-
new_ik = (new_idx % nb2);
840-
new_idx = new_im * ngroups_per_elem * simd_n_in * k / g + new_isni * ngroups_per_elem * k / g + new_ing * k / g + new_ik;
857+
new_im = new_idx / c1_nb0;
858+
int new_ing = (new_idx % c1_nb0) / c1_nb1;
859+
int new_isni = (new_idx % c1_nb1) / c1_nb2;
860+
new_ik = (new_idx % c1_nb2);
861+
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
862+
841863
// # 0 1 2 3 4 5
842864
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
843-
int nb4 = kfactor;
844-
int nb3 = k / g / kfactor * nb4;
845-
nb2 = ngroups_per_elem * nb3;
846-
nb1 = simd_n_in * nb2;
847-
nb0 = bm / mgroup * nb1;
848-
new_im = new_idx / nb0;
849-
int new_ibm = (new_idx % nb0) / nb1;
850-
new_isni = (new_idx % nb1) / nb2;
851-
new_ing = (new_idx % nb2) / nb3;
852-
new_ik = (new_idx % nb3) / nb4;
853-
int new_ikf = (new_idx % nb4);
854-
new_idx = new_im * k / g / kfactor * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
855-
new_ik * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
856-
new_ibm * kfactor * simd_n_in * ngroups_per_elem +
857-
new_ikf * simd_n_in * ngroups_per_elem +
865+
new_im = new_idx / c2_nb0;
866+
int new_ibm = (new_idx % c2_nb0) / c2_nb1;
867+
new_isni = (new_idx % c2_nb1) / c2_nb2;
868+
new_ing = (new_idx % c2_nb2) / c2_nb3;
869+
new_ik = (new_idx % c2_nb3) / c2_nb4;
870+
int new_ikf = (new_idx % c2_nb4);
871+
new_idx = new_im * c2_fac0 +
872+
new_ik * c2_fac1 +
873+
new_ibm * c2_fac2 +
874+
new_ikf * c2_fac3 +
858875
new_isni * ngroups_per_elem +
859876
new_ing;
860877
new_idx = new_idx / ngroups_per_elem;
878+
861879
// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
862880
qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g);
863881
}

0 commit comments

Comments
 (0)