@@ -840,47 +840,116 @@ class SketchContainerImpl {
840
840
template <typename Batch, typename IsValid>
841
841
void PushRowPageImpl (Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
842
842
size_t n_features, bool is_dense, IsValid is_valid) {
843
- auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
844
-
845
843
dmlc::OMPException exc;
846
- #pragma omp parallel num_threads(n_threads_)
847
- {
848
- exc.Run ([&]() {
849
- auto tid = static_cast <uint32_t >(omp_get_thread_num ());
850
- auto const begin = thread_columns_ptr[tid];
851
- auto const end = thread_columns_ptr[tid + 1 ];
852
-
853
- // do not iterate if no columns are assigned to the thread
854
- if (begin < end && end <= n_features) {
855
- for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
844
+ size_t ridx_block_size = batch.Size () / n_threads_ + (batch.Size () % n_threads_ > 0 );
845
+ size_t min_ridx_block_size = 1024 ;
846
+ if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) {
847
+ /* Row-wise parallelisation.
848
+ */
849
+ std::vector<std::set<float >> categories_buff (n_threads_ * n_features);
850
+ std::vector<WQSketch> sketches_buff (n_threads_ * n_features);
851
+
852
+ #pragma omp parallel num_threads(n_threads_)
853
+ {
854
+ exc.Run ([&]() {
855
+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
856
+ WQSketch* sketches_th = sketches_buff.data () + tid * n_features;
857
+ std::set<float >* categories_th = categories_buff.data () + tid * n_features;
858
+
859
+ for (size_t ii = 0 ; ii < n_features; ii++) {
860
+ auto n_bins = std::min (static_cast <bst_idx_t >(max_bins_), columns_size_[ii]);
861
+ auto eps = 1.0 / (static_cast <float >(n_bins) * WQSketch::kFactor );
862
+ sketches_th[ii].Init (columns_size_[ii], eps);
863
+ }
864
+
865
+ size_t ridx_begin = tid * ridx_block_size;
866
+ size_t ridx_end = std::min (ridx_begin + ridx_block_size, batch.Size ());
867
+ for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
856
868
auto const &line = batch.GetLine (ridx);
857
869
auto w = weights[ridx + base_rowid];
858
870
if (is_dense) {
859
- for (size_t ii = begin ; ii < end ; ii++) {
871
+ for (size_t ii = 0 ; ii < n_features ; ii++) {
860
872
auto elem = line.GetElement (ii);
861
873
if (is_valid (elem)) {
862
874
if (IsCat (feature_types_, ii)) {
863
- categories_ [ii].emplace (elem.value );
875
+ categories_th [ii].emplace (elem.value );
864
876
} else {
865
- sketches_ [ii].Push (elem.value , w);
877
+ sketches_th [ii].Push (elem.value , w);
866
878
}
867
879
}
868
880
}
869
881
} else {
870
- for (size_t i = 0 ; i < line.Size (); ++i ) {
871
- auto const & elem = line.GetElement (i );
872
- if (is_valid (elem) && elem. column_idx >= begin && elem. column_idx < end ) {
882
+ for (size_t ii = 0 ; ii < line.Size (); ++ii ) {
883
+ auto elem = line.GetElement (ii );
884
+ if (is_valid (elem)) {
873
885
if (IsCat (feature_types_, elem.column_idx )) {
874
- categories_ [elem.column_idx ].emplace (elem.value );
886
+ categories_th [elem.column_idx ].emplace (elem.value );
875
887
} else {
876
- sketches_ [elem.column_idx ].Push (elem.value , w);
888
+ sketches_th [elem.column_idx ].Push (elem.value , w);
877
889
}
878
890
}
879
891
}
880
892
}
881
893
}
882
- }
883
- });
894
+ #pragma omp barrier
895
+
896
+ size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0 );
897
+ size_t fidx_begin = tid * fidx_block_size;
898
+ size_t fidx_end = std::min (fidx_begin + fidx_block_size, n_features);
899
+ for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
900
+ for (size_t th = 0 ; th < n_threads_; ++th) {
901
+ if (IsCat (feature_types_, ii)) {
902
+ categories_[ii].merge (categories_buff[th * n_features + ii]);
903
+ } else {
904
+ typename WQSketch::SummaryContainer summary;
905
+ sketches_buff[th * n_features + ii].GetSummary (&summary);
906
+ sketches_[ii].PushSummary (summary);
907
+ }
908
+ }
909
+ }
910
+ });
911
+ }
912
+ } else {
913
+ auto thread_columns_ptr = LoadBalance (batch, nnz, n_features, n_threads_, is_valid);
914
+ #pragma omp parallel num_threads(n_threads_)
915
+ {
916
+ exc.Run ([&]() {
917
+ auto tid = static_cast <uint32_t >(omp_get_thread_num ());
918
+ auto const begin = thread_columns_ptr[tid];
919
+ auto const end = thread_columns_ptr[tid + 1 ];
920
+
921
+ // do not iterate if no columns are assigned to the thread
922
+ if (begin < end && end <= n_features) {
923
+ for (size_t ridx = 0 ; ridx < batch.Size (); ++ridx) {
924
+ auto const &line = batch.GetLine (ridx);
925
+ auto w = weights[ridx + base_rowid];
926
+ if (is_dense) {
927
+ for (size_t ii = begin; ii < end; ii++) {
928
+ auto elem = line.GetElement (ii);
929
+ if (is_valid (elem)) {
930
+ if (IsCat (feature_types_, ii)) {
931
+ categories_[ii].emplace (elem.value );
932
+ } else {
933
+ sketches_[ii].Push (elem.value , w);
934
+ }
935
+ }
936
+ }
937
+ } else {
938
+ for (size_t i = 0 ; i < line.Size (); ++i) {
939
+ auto const &elem = line.GetElement (i);
940
+ if (is_valid (elem) && elem.column_idx >= begin && elem.column_idx < end) {
941
+ if (IsCat (feature_types_, elem.column_idx )) {
942
+ categories_[elem.column_idx ].emplace (elem.value );
943
+ } else {
944
+ sketches_[elem.column_idx ].Push (elem.value , w);
945
+ }
946
+ }
947
+ }
948
+ }
949
+ }
950
+ }
951
+ });
952
+ }
884
953
}
885
954
exc.Rethrow ();
886
955
}
0 commit comments