 #include <fstream>
 #include <stdexcept>
 #include <string>
+#include <thread>
 #include <unordered_map>
 
 #define GGML_COMMON_IMPL_CPP
@@ -773,6 +774,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
 // for fast testing
 // #define TMAC_EMPTY_WEIGHTS
 #ifndef TMAC_EMPTY_WEIGHTS
+    std::vector<std::thread> threads;
+    const int n_threads = std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 1;  // hardware_concurrency() may return 0
+
     // TODO: optimize to accelerate weights loading
     uint8_t * buf2 = new uint8_t[m * k / g];
     memset(buf2, 0, m * k / g);
@@ -782,7 +786,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
     // # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g)
     // w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g)
     // w = sum([(w[:, :, :, ig] << ig) for ig in range(g)])
-    for (int im = 0; im < m / bits; im++) {
+    threads.reserve(n_threads);
+    auto parallel_worker_buf2 = [&](size_t start_index, size_t end_index) {  // packs rows [start_index, end_index) into buf2
+    for (int im = start_index; im < end_index; im++) {
         for (int ik = 0; ik < k; ik++) {
             uint8_t v;
             if (tensor->type == GGML_TYPE_Q4_0) {
@@ -808,6 +814,25 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
             }
         }
     }
+    };
+
+    size_t start_index = 0;
+    size_t chunk_size = m / bits / n_threads;
+    for (size_t i = 0; i < n_threads; ++i) {
+        size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size;
+
+        // Create and launch a thread
+        threads.emplace_back(parallel_worker_buf2,
+                             start_index,
+                             end_index);
+
+        start_index = end_index;
+    }
+    // Wait for all threads to complete
+    for (std::thread & t : threads) {
+        t.join();
+    }
+    threads.clear();
 
     // # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
     // # for bits=3
@@ -843,7 +868,9 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
     int c2_fac1 = bm / mgroup * c2_fac2;
     int c2_fac0 = k / g / kfactor * c2_fac1;
 
-    for (int im = 0; im < m / bits; im++) {
+    threads.reserve(n_threads);
+    auto parallel_worker_qweights = [&](size_t start_index, size_t end_index) {  // packs rows [start_index, end_index) of the qweights layout
+    for (int im = start_index; im < end_index; im++) {
         for (int ib = 0; ib < bits; ib++) {
             for (int ik = 0; ik < k / g; ik++) {
                 // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
@@ -881,6 +908,25 @@ static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const
             }
         }
     }
+    };
+
+    start_index = 0;
+    chunk_size = m / bits / n_threads;
+    for (size_t i = 0; i < n_threads; ++i) {
+        size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size;
+
+        // Create and launch a thread
+        threads.emplace_back(parallel_worker_qweights,
+                             start_index,
+                             end_index);
+
+        start_index = end_index;
+    }
+    // Wait for all threads to complete
+    for (std::thread & t : threads) {
+        t.join();
+    }
+    threads.clear();
 
     const float * int_n_scales = (const float *) ((const uint8_t *) origin_data + k * m / 8);
     const float * int_n_zero_points = int_n_scales + scales_size / 2;
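
Both hunks apply the same chunked-dispatch pattern: split the row range into n_threads contiguous chunks, hand each std::thread a disjoint [start, end) slice so workers write non-overlapping rows and need no locking, then join and reuse the vector. Below is a minimal, self-contained sketch of that pattern; the names total and pack_rows are illustrative and not from this commit.

#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const size_t total = 1024;  // stands in for m / bits in the commit
    // hardware_concurrency() may legally return 0, so fall back to one thread
    size_t n_threads = std::thread::hardware_concurrency();
    if (n_threads == 0) n_threads = 1;

    // Each worker handles a half-open range [start, end); ranges never overlap,
    // so the threads touch disjoint output rows and no mutex is needed.
    auto pack_rows = [](size_t start, size_t end) {
        for (size_t i = start; i < end; ++i) {
            // per-row packing work would go here
        }
    };

    std::vector<std::thread> threads;
    threads.reserve(n_threads);

    size_t start = 0;
    const size_t chunk = total / n_threads;
    for (size_t i = 0; i < n_threads; ++i) {
        // The last thread takes whatever remains, so every index is covered
        // even when total is not divisible by n_threads.
        const size_t end = (i == n_threads - 1) ? total : start + chunk;
        threads.emplace_back(pack_rows, start, end);
        start = end;
    }
    for (std::thread & t : threads) {
        t.join();  // block until every chunk has been processed
    }
    std::printf("packed %zu rows on %zu threads\n", total, n_threads);
    return 0;
}

The lambdas in the commit capture by reference, so they read tensor, buf2 and the packing parameters directly; only the row-range bounds are passed by value to emplace_back.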