@@ -840,47 +840,117 @@ class SketchContainerImpl {
840
840
template <typename Batch, typename IsValid>
841
841
void PushRowPageImpl (Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
842
842
size_t n_features, bool is_dense, IsValid is_valid) {
843
- auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
844
-
845
843
dmlc::OMPException exc;
846
- #pragma omp parallel num_threads(n_threads_)
847
- {
848
- exc.Run ([&]() {
849
- auto tid = static_cast <uint32_t >(omp_get_thread_num ());
850
- auto const begin = thread_columns_ptr[tid];
851
- auto const end = thread_columns_ptr[tid + 1 ];
852
-
853
- // do not iterate if no columns are assigned to the thread
854
- if (begin < end && end <= n_features) {
855
- for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
844
+ size_t ridx_block_size = batch.Size () / n_threads_ + (batch.Size () % n_threads_ > 0 );
845
+ size_t min_ridx_block_size = 1024 ;
846
+ if ((n_features < static_cast <size_t >(n_threads_)) &&
847
+ (ridx_block_size > min_ridx_block_size)) {
848
+ /* Row-wise parallelisation.
849
+ */
850
+ std::vector<std::set<float >> categories_buff (n_threads_ * n_features);
851
+ std::vector<WQSketch> sketches_buff (n_threads_ * n_features);
852
+
853
+ #pragma omp parallel num_threads(n_threads_)
854
+ {
855
+ exc.Run ([&]() {
856
+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
857
+ WQSketch* sketches_th = sketches_buff.data () + tid * n_features;
858
+ std::set<float >* categories_th = categories_buff.data () + tid * n_features;
859
+
860
+ for (size_t ii = 0 ; ii < n_features; ii++) {
861
+ auto n_bins = std::min (static_cast <bst_idx_t >(max_bins_), columns_size_[ii]);
862
+ auto eps = 1.0 / (static_cast <float >(n_bins) * WQSketch::kFactor );
863
+ sketches_th[ii].Init (columns_size_[ii], eps);
864
+ }
865
+
866
+ size_t ridx_begin = tid * ridx_block_size;
867
+ size_t ridx_end = std::min (ridx_begin + ridx_block_size, batch.Size ());
868
+ for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
856
869
auto const &line = batch.GetLine (ridx);
857
870
auto w = weights[ridx + base_rowid];
858
871
if (is_dense) {
859
- for (size_t ii = begin ; ii < end ; ii++) {
872
+ for (size_t ii = 0 ; ii < n_features ; ii++) {
860
873
auto elem = line.GetElement (ii);
861
874
if (is_valid (elem)) {
862
875
if (IsCat (feature_types_, ii)) {
863
- categories_ [ii].emplace (elem.value );
876
+ categories_th [ii].emplace (elem.value );
864
877
} else {
865
- sketches_ [ii].Push (elem.value , w);
878
+ sketches_th [ii].Push (elem.value , w);
866
879
}
867
880
}
868
881
}
869
882
} else {
870
- for (size_t i = 0 ; i < line.Size (); ++i ) {
871
- auto const & elem = line.GetElement (i );
872
- if (is_valid (elem) && elem. column_idx >= begin && elem. column_idx < end ) {
883
+ for (size_t ii = 0 ; ii < line.Size (); ++ii ) {
884
+ auto elem = line.GetElement (ii );
885
+ if (is_valid (elem)) {
873
886
if (IsCat (feature_types_, elem.column_idx )) {
874
- categories_ [elem.column_idx ].emplace (elem.value );
887
+ categories_th [elem.column_idx ].emplace (elem.value );
875
888
} else {
876
- sketches_ [elem.column_idx ].Push (elem.value , w);
889
+ sketches_th [elem.column_idx ].Push (elem.value , w);
877
890
}
878
891
}
879
892
}
880
893
}
881
894
}
882
- }
883
- });
895
+ #pragma omp barrier
896
+
897
+ size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0 );
898
+ size_t fidx_begin = tid * fidx_block_size;
899
+ size_t fidx_end = std::min (fidx_begin + fidx_block_size, n_features);
900
+ for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
901
+ for (int th = 0 ; th < n_threads_; ++th) {
902
+ if (IsCat (feature_types_, ii)) {
903
+ categories_[ii].merge (categories_buff[th * n_features + ii]);
904
+ } else {
905
+ typename WQSketch::SummaryContainer summary;
906
+ sketches_buff[th * n_features + ii].GetSummary (&summary);
907
+ sketches_[ii].PushSummary (summary);
908
+ }
909
+ }
910
+ }
911
+ });
912
+ }
913
+ } else {
914
+ auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
915
+ #pragma omp parallel num_threads(n_threads_)
916
+ {
917
+ exc.Run ([&]() {
918
+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
919
+ auto const begin = thread_columns_ptr[tid];
920
+ auto const end = thread_columns_ptr[tid + 1 ];
921
+
922
+ // do not iterate if no columns are assigned to the thread
923
+ if (begin < end && end <= n_features) {
924
+ for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
925
+ auto const &line = batch.GetLine (ridx);
926
+ auto w = weights[ridx + base_rowid];
927
+ if (is_dense) {
928
+ for (size_t ii = begin; ii < end; ii++) {
929
+ auto elem = line.GetElement (ii);
930
+ if (is_valid (elem)) {
931
+ if (IsCat (feature_types_, ii)) {
932
+ categories_[ii].emplace (elem.value );
933
+ } else {
934
+ sketches_[ii].Push (elem.value , w);
935
+ }
936
+ }
937
+ }
938
+ } else {
939
+ for (size_t i = 0 ; i < line.Size (); ++i) {
940
+ auto const &elem = line.GetElement (i);
941
+ if (is_valid (elem) && elem.column_idx >= begin && elem.column_idx < end) {
942
+ if (IsCat (feature_types_, elem.column_idx )) {
943
+ categories_[elem.column_idx ].emplace (elem.value );
944
+ } else {
945
+ sketches_[elem.column_idx ].Push (elem.value , w);
946
+ }
947
+ }
948
+ }
949
+ }
950
+ }
951
+ }
952
+ });
953
+ }
884
954
}
885
955
exc.Rethrow ();
886
956
}
0 commit comments