4
4
*/
5
5
#include " xgboost/data.h"
6
6
7
- #include < dmlc/registry.h>
8
-
9
- #include < array>
10
- #include < cstddef>
11
- #include < cstring>
12
-
13
- #include " ../collective/communicator-inl.h"
14
- #include " ../collective/communicator.h"
15
- #include " ../common/algorithm.h" // for StableSort
16
- #include " ../common/api_entry.h" // for XGBAPIThreadLocalEntry
17
- #include " ../common/common.h"
18
- #include " ../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
19
- #include " ../common/group_data.h"
20
- #include " ../common/io.h"
21
- #include " ../common/linalg_op.h"
22
- #include " ../common/math.h"
23
- #include " ../common/numeric.h" // for Iota
24
- #include " ../common/threading_utils.h"
25
- #include " ../common/version.h"
26
- #include " ../data/adapter.h"
27
- #include " ../data/iterative_dmatrix.h"
28
- #include " ./sparse_page_dmatrix.h"
29
- #include " ./sparse_page_source.h"
30
- #include " dmlc/io.h"
31
- #include " file_iterator.h"
32
- #include " simple_dmatrix.h"
33
- #include " sparse_page_writer.h"
34
- #include " validation.h"
35
- #include " xgboost/c_api.h"
36
- #include " xgboost/context.h"
37
- #include " xgboost/host_device_vector.h"
38
- #include " xgboost/learner.h"
39
- #include " xgboost/linalg.h" // Vector
40
- #include " xgboost/logging.h"
41
- #include " xgboost/string_view.h"
42
- #include " xgboost/version_config.h"
7
+ #include < dmlc/registry.h> // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG
8
+
9
+ #include < algorithm> // for copy, max, none_of, min
10
+ #include < atomic> // for atomic
11
+ #include < cmath> // for abs
12
+ #include < cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
13
+ #include < cstring> // for size_t, strcmp, memcpy
14
+ #include < exception> // for exception
15
+ #include < iostream> // for operator<<, basic_ostream, basic_ostream::op...
16
+ #include < map> // for map, operator!=
17
+ #include < numeric> // for accumulate, partial_sum
18
+ #include < tuple> // for get, apply
19
+ #include < type_traits> // for remove_pointer_t, remove_reference
20
+
21
+ #include " ../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
22
+ #include " ../collective/communicator.h" // for Operation
23
+ #include " ../common/algorithm.h" // for StableSort
24
+ #include " ../common/api_entry.h" // for XGBAPIThreadLocalEntry
25
+ #include " ../common/common.h" // for Split
26
+ #include " ../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
27
+ #include " ../common/group_data.h" // for ParallelGroupBuilder
28
+ #include " ../common/io.h" // for PeekableInStream
29
+ #include " ../common/linalg_op.h" // for ElementWiseTransformHost
30
+ #include " ../common/math.h" // for CheckNAN
31
+ #include " ../common/numeric.h" // for Iota, RunLengthEncode
32
+ #include " ../common/threading_utils.h" // for ParallelFor
33
+ #include " ../common/version.h" // for Version
34
+ #include " ../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
35
+ #include " ../data/iterative_dmatrix.h" // for IterativeDMatrix
36
+ #include " ./sparse_page_dmatrix.h" // for SparsePageDMatrix
37
+ #include " array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
38
+ #include " dmlc/base.h" // for BeginPtr
39
+ #include " dmlc/common.h" // for OMPException
40
+ #include " dmlc/data.h" // for Parser
41
+ #include " dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
42
+ #include " dmlc/io.h" // for Stream
43
+ #include " dmlc/thread_local.h" // for ThreadLocalStore
44
+ #include " ellpack_page.h" // for EllpackPage
45
+ #include " file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
46
+ #include " gradient_index.h" // for GHistIndexMatrix
47
+ #include " simple_dmatrix.h" // for SimpleDMatrix
48
+ #include " sparse_page_writer.h" // for SparsePageFormatReg
49
+ #include " validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
50
+ #include " xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
51
+ #include " xgboost/context.h" // for Context
52
+ #include " xgboost/host_device_vector.h" // for HostDeviceVector
53
+ #include " xgboost/learner.h" // for HostDeviceVector
54
+ #include " xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
55
+ #include " xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
56
+ #include " xgboost/span.h" // for Span, operator!=, SpanIterator
57
+ #include " xgboost/string_view.h" // for operator==, operator<<, StringView
43
58
44
59
namespace dmlc {
45
60
DMLC_REGISTRY_ENABLE (::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@@ -811,29 +826,29 @@ DMatrix::~DMatrix() {
811
826
}
812
827
}
813
828
814
- DMatrix * TryLoadBinary (std::string fname, bool silent) {
815
- int magic;
816
- std::unique_ptr<dmlc::Stream> fi (
817
- dmlc::Stream::Create (fname.c_str (), " r" , true ));
829
+ namespace {
830
+ DMatrix* TryLoadBinary (std::string fname, bool silent) {
831
+ std::int32_t magic;
832
+ std::unique_ptr<dmlc::Stream> fi ( dmlc::Stream::Create (fname.c_str (), " r" , true ));
818
833
if (fi != nullptr ) {
819
834
common::PeekableInStream is (fi.get ());
820
835
if (is.PeekRead (&magic, sizeof (magic)) == sizeof (magic)) {
821
836
if (!DMLC_IO_NO_ENDIAN_SWAP) {
822
837
dmlc::ByteSwap (&magic, sizeof (magic), 1 );
823
838
}
824
839
if (magic == data::SimpleDMatrix::kMagic ) {
825
- DMatrix * dmat = new data::SimpleDMatrix (&is);
840
+ DMatrix* dmat = new data::SimpleDMatrix (&is);
826
841
if (!silent) {
827
- LOG (CONSOLE) << dmat->Info ().num_row_ << ' x' << dmat->Info ().num_col_
828
- << " matrix with " << dmat->Info ().num_nonzero_
829
- << " entries loaded from " << fname;
842
+ LOG (CONSOLE) << dmat->Info ().num_row_ << ' x' << dmat->Info ().num_col_ << " matrix with "
843
+ << dmat->Info ().num_nonzero_ << " entries loaded from " << fname;
830
844
}
831
845
return dmat;
832
846
}
833
847
}
834
848
}
835
849
return nullptr ;
836
850
}
851
+ } // namespace
837
852
838
853
DMatrix* DMatrix::Load (const std::string& uri, bool silent, DataSplitMode data_split_mode) {
839
854
auto need_split = false ;
@@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
845
860
}
846
861
847
862
std::string fname, cache_file;
848
- size_t dlm_pos = uri.find (' #' );
863
+ auto dlm_pos = uri.find (' #' );
849
864
if (dlm_pos != std::string::npos) {
850
865
cache_file = uri.substr (dlm_pos + 1 , uri.length ());
851
866
fname = uri.substr (0 , dlm_pos);
@@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
857
872
for (size_t i = 0 ; i < cache_shards.size (); ++i) {
858
873
size_t pos = cache_shards[i].rfind (' .' );
859
874
if (pos == std::string::npos) {
860
- os << cache_shards[i]
861
- << " .r" << collective::GetRank ()
862
- << " -" << collective::GetWorldSize ();
875
+ os << cache_shards[i] << " .r" << collective::GetRank () << " -"
876
+ << collective::GetWorldSize ();
863
877
} else {
864
- os << cache_shards[i].substr (0 , pos)
865
- << " .r" << collective::GetRank ()
866
- << " -" << collective::GetWorldSize ()
867
- << cache_shards[i].substr (pos, cache_shards[i].length ());
878
+ os << cache_shards[i].substr (0 , pos) << " .r" << collective::GetRank () << " -"
879
+ << collective::GetWorldSize () << cache_shards[i].substr (pos, cache_shards[i].length ());
868
880
}
869
881
if (i + 1 != cache_shards.size ()) {
870
882
os << ' :' ;
@@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
895
907
LOG (CONSOLE) << " Load part of data " << partid << " of " << npart << " parts" ;
896
908
}
897
909
898
- data::ValidateFileFormat (fname);
899
- DMatrix* dmat {nullptr };
910
+ DMatrix* dmat{nullptr };
900
911
901
912
if (cache_file.empty ()) {
902
- std::unique_ptr<dmlc::Parser<uint32_t >> parser (
903
- dmlc::Parser<uint32_t >::Create (fname.c_str (), partid, npart, " auto" ));
913
+ fname = data::ValidateFileFormat (fname);
914
+ std::unique_ptr<dmlc::Parser<std::uint32_t >> parser (
915
+ dmlc::Parser<std::uint32_t >::Create (fname.c_str (), partid, npart, " auto" ));
904
916
data::FileAdapter adapter (parser.get ());
905
917
dmat = DMatrix::Create (&adapter, std::numeric_limits<float >::quiet_NaN (), Context{}.Threads (),
906
918
cache_file, data_split_mode);
0 commit comments