@@ -840,47 +840,117 @@ class SketchContainerImpl {
840840 template <typename Batch, typename IsValid>
841841 void PushRowPageImpl (Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
842842 size_t n_features, bool is_dense, IsValid is_valid) {
843- auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
844-
845843 dmlc::OMPException exc;
846- #pragma omp parallel num_threads(n_threads_)
847- {
848- exc.Run ([&]() {
849- auto tid = static_cast <uint32_t >(omp_get_thread_num ());
850- auto const begin = thread_columns_ptr[tid];
851- auto const end = thread_columns_ptr[tid + 1 ];
852-
853- // do not iterate if no columns are assigned to the thread
854- if (begin < end && end <= n_features) {
855- for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
844+ size_t ridx_block_size = batch.Size () / n_threads_ + (batch.Size () % n_threads_ > 0 );
845+ size_t min_ridx_block_size = 1024 ;
846+ if ((n_features < static_cast <size_t >(n_threads_)) &&
847+ (ridx_block_size > min_ridx_block_size)) {
848+ /* Row-wise parallelisation.
849+ */
850+ std::vector<std::set<float >> categories_buff (n_threads_ * n_features);
851+ std::vector<WQSketch> sketches_buff (n_threads_ * n_features);
852+
853+ #pragma omp parallel num_threads(n_threads_)
854+ {
855+ exc.Run ([&]() {
856+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
857+ WQSketch* sketches_th = sketches_buff.data () + tid * n_features;
858+ std::set<float >* categories_th = categories_buff.data () + tid * n_features;
859+
860+ for (size_t ii = 0 ; ii < n_features; ii++) {
861+ auto n_bins = std::min (static_cast <bst_idx_t >(max_bins_), columns_size_[ii]);
862+ auto eps = 1.0 / (static_cast <float >(n_bins) * WQSketch::kFactor );
863+ sketches_th[ii].Init (columns_size_[ii], eps);
864+ }
865+
866+ size_t ridx_begin = tid * ridx_block_size;
867+ size_t ridx_end = std::min (ridx_begin + ridx_block_size, batch.Size ());
868+ for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
856869 auto const &line = batch.GetLine (ridx);
857870 auto w = weights[ridx + base_rowid];
858871 if (is_dense) {
859- for (size_t ii = begin ; ii < end ; ii++) {
872+ for (size_t ii = 0 ; ii < n_features ; ii++) {
860873 auto elem = line.GetElement (ii);
861874 if (is_valid (elem)) {
862875 if (IsCat (feature_types_, ii)) {
863- categories_ [ii].emplace (elem.value );
876+ categories_th [ii].emplace (elem.value );
864877 } else {
865- sketches_ [ii].Push (elem.value , w);
878+ sketches_th [ii].Push (elem.value , w);
866879 }
867880 }
868881 }
869882 } else {
870- for (size_t i = 0 ; i < line.Size (); ++i ) {
871- auto const & elem = line.GetElement (i );
872- if (is_valid (elem) && elem. column_idx >= begin && elem. column_idx < end ) {
883+ for (size_t ii = 0 ; ii < line.Size (); ++ii ) {
884+ auto elem = line.GetElement (ii );
885+ if (is_valid (elem)) {
873886 if (IsCat (feature_types_, elem.column_idx )) {
874- categories_ [elem.column_idx ].emplace (elem.value );
887+ categories_th [elem.column_idx ].emplace (elem.value );
875888 } else {
876- sketches_ [elem.column_idx ].Push (elem.value , w);
889+ sketches_th [elem.column_idx ].Push (elem.value , w);
877890 }
878891 }
879892 }
880893 }
881894 }
882- }
883- });
895+ #pragma omp barrier
896+
897+ size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0 );
898+ size_t fidx_begin = tid * fidx_block_size;
899+ size_t fidx_end = std::min (fidx_begin + fidx_block_size, n_features);
900+ for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
901+ for (int th = 0 ; th < n_threads_; ++th) {
902+ if (IsCat (feature_types_, ii)) {
903+ categories_[ii].merge (categories_buff[th * n_features + ii]);
904+ } else {
905+ typename WQSketch::SummaryContainer summary;
906+ sketches_buff[th * n_features + ii].GetSummary (&summary);
907+ sketches_[ii].PushSummary (summary);
908+ }
909+ }
910+ }
911+ });
912+ }
913+ } else {
914+ auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
915+ #pragma omp parallel num_threads(n_threads_)
916+ {
917+ exc.Run ([&]() {
918+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
919+ auto const begin = thread_columns_ptr[tid];
920+ auto const end = thread_columns_ptr[tid + 1 ];
921+
922+ // do not iterate if no columns are assigned to the thread
923+ if (begin < end && end <= n_features) {
924+ for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
925+ auto const &line = batch.GetLine (ridx);
926+ auto w = weights[ridx + base_rowid];
927+ if (is_dense) {
928+ for (size_t ii = begin; ii < end; ii++) {
929+ auto elem = line.GetElement (ii);
930+ if (is_valid (elem)) {
931+ if (IsCat (feature_types_, ii)) {
932+ categories_[ii].emplace (elem.value );
933+ } else {
934+ sketches_[ii].Push (elem.value , w);
935+ }
936+ }
937+ }
938+ } else {
939+ for (size_t i = 0 ; i < line.Size (); ++i) {
940+ auto const &elem = line.GetElement (i);
941+ if (is_valid (elem) && elem.column_idx >= begin && elem.column_idx < end) {
942+ if (IsCat (feature_types_, elem.column_idx )) {
943+ categories_[elem.column_idx ].emplace (elem.value );
944+ } else {
945+ sketches_[elem.column_idx ].Push (elem.value , w);
946+ }
947+ }
948+ }
949+ }
950+ }
951+ }
952+ });
953+ }
884954 }
885955 exc.Rethrow ();
886956 }
0 commit comments