Skip to content

Commit bb56183

Browse files
authored
Normalize file system path. (dmlc#9463)
1 parent bdc1a3c commit bb56183

File tree

7 files changed

+141
-102
lines changed

7 files changed

+141
-102
lines changed

R-package/src/Makevars.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ OBJECTS= \
4747
$(PKGROOT)/src/data/data.o \
4848
$(PKGROOT)/src/data/sparse_page_raw_format.o \
4949
$(PKGROOT)/src/data/ellpack_page.o \
50+
$(PKGROOT)/src/data/file_iterator.o \
5051
$(PKGROOT)/src/data/gradient_index.o \
5152
$(PKGROOT)/src/data/gradient_index_page_source.o \
5253
$(PKGROOT)/src/data/gradient_index_format.o \

R-package/src/Makevars.win

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ OBJECTS= \
4747
$(PKGROOT)/src/data/data.o \
4848
$(PKGROOT)/src/data/sparse_page_raw_format.o \
4949
$(PKGROOT)/src/data/ellpack_page.o \
50+
$(PKGROOT)/src/data/file_iterator.o \
5051
$(PKGROOT)/src/data/gradient_index.o \
5152
$(PKGROOT)/src/data/gradient_index_page_source.o \
5253
$(PKGROOT)/src/data/gradient_index_format.o \

R-package/tests/testthat/test_dmatrix.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
7272
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
7373
tmp_file <- tempfile(fileext = ".libsvm")
7474
writeLines(tmp, tmp_file)
75+
expect_true(file.exists(tmp_file))
7576
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
7677
expect_equal(dim(dtest4), c(3, 4))
7778
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))

src/common/io.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include <cstddef> // for size_t
2929
#include <cstdint> // for int32_t, uint32_t
3030
#include <cstring> // for memcpy
31-
#include <filesystem> // for filesystem
31+
#include <filesystem> // for filesystem, weakly_canonical
3232
#include <fstream> // for ifstream
3333
#include <iterator> // for distance
3434
#include <limits> // for numeric_limits
@@ -154,7 +154,8 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
154154
// Open in binary mode so that correct file size can be computed with
155155
// seekg(). This accommodates Windows platform:
156156
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
157-
std::ifstream ifs(std::filesystem::u8path(uri), std::ios_base::binary | std::ios_base::in);
157+
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
158+
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
158159
if (!ifs) {
159160
// https://stackoverflow.com/a/17338934
160161
OpenErr();

src/data/data.cc

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,57 @@
44
*/
55
#include "xgboost/data.h"
66

7-
#include <dmlc/registry.h>
8-
9-
#include <array>
10-
#include <cstddef>
11-
#include <cstring>
12-
13-
#include "../collective/communicator-inl.h"
14-
#include "../collective/communicator.h"
15-
#include "../common/algorithm.h" // for StableSort
16-
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
17-
#include "../common/common.h"
18-
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
19-
#include "../common/group_data.h"
20-
#include "../common/io.h"
21-
#include "../common/linalg_op.h"
22-
#include "../common/math.h"
23-
#include "../common/numeric.h" // for Iota
24-
#include "../common/threading_utils.h"
25-
#include "../common/version.h"
26-
#include "../data/adapter.h"
27-
#include "../data/iterative_dmatrix.h"
28-
#include "./sparse_page_dmatrix.h"
29-
#include "./sparse_page_source.h"
30-
#include "dmlc/io.h"
31-
#include "file_iterator.h"
32-
#include "simple_dmatrix.h"
33-
#include "sparse_page_writer.h"
34-
#include "validation.h"
35-
#include "xgboost/c_api.h"
36-
#include "xgboost/context.h"
37-
#include "xgboost/host_device_vector.h"
38-
#include "xgboost/learner.h"
39-
#include "xgboost/linalg.h" // Vector
40-
#include "xgboost/logging.h"
41-
#include "xgboost/string_view.h"
42-
#include "xgboost/version_config.h"
7+
#include <dmlc/registry.h> // for DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_LINK_TAG
8+
9+
#include <algorithm> // for copy, max, none_of, min
10+
#include <atomic> // for atomic
11+
#include <cmath> // for abs
12+
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
13+
#include <cstring> // for size_t, strcmp, memcpy
14+
#include <exception> // for exception
15+
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
16+
#include <map> // for map, operator!=
17+
#include <numeric> // for accumulate, partial_sum
18+
#include <tuple> // for get, apply
19+
#include <type_traits> // for remove_pointer_t, remove_reference
20+
21+
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
22+
#include "../collective/communicator.h" // for Operation
23+
#include "../common/algorithm.h" // for StableSort
24+
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
25+
#include "../common/common.h" // for Split
26+
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
27+
#include "../common/group_data.h" // for ParallelGroupBuilder
28+
#include "../common/io.h" // for PeekableInStream
29+
#include "../common/linalg_op.h" // for ElementWiseTransformHost
30+
#include "../common/math.h" // for CheckNAN
31+
#include "../common/numeric.h" // for Iota, RunLengthEncode
32+
#include "../common/threading_utils.h" // for ParallelFor
33+
#include "../common/version.h" // for Version
34+
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
35+
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
36+
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
37+
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
38+
#include "dmlc/base.h" // for BeginPtr
39+
#include "dmlc/common.h" // for OMPException
40+
#include "dmlc/data.h" // for Parser
41+
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
42+
#include "dmlc/io.h" // for Stream
43+
#include "dmlc/thread_local.h" // for ThreadLocalStore
44+
#include "ellpack_page.h" // for EllpackPage
45+
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
46+
#include "gradient_index.h" // for GHistIndexMatrix
47+
#include "simple_dmatrix.h" // for SimpleDMatrix
48+
#include "sparse_page_writer.h" // for SparsePageFormatReg
49+
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
50+
#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
51+
#include "xgboost/context.h" // for Context
52+
#include "xgboost/host_device_vector.h" // for HostDeviceVector
53+
#include "xgboost/learner.h" // for HostDeviceVector
54+
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
55+
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
56+
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
57+
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
4358

4459
namespace dmlc {
4560
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
@@ -811,29 +826,29 @@ DMatrix::~DMatrix() {
811826
}
812827
}
813828

814-
DMatrix *TryLoadBinary(std::string fname, bool silent) {
815-
int magic;
816-
std::unique_ptr<dmlc::Stream> fi(
817-
dmlc::Stream::Create(fname.c_str(), "r", true));
829+
namespace {
830+
DMatrix* TryLoadBinary(std::string fname, bool silent) {
831+
std::int32_t magic;
832+
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
818833
if (fi != nullptr) {
819834
common::PeekableInStream is(fi.get());
820835
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) {
821836
if (!DMLC_IO_NO_ENDIAN_SWAP) {
822837
dmlc::ByteSwap(&magic, sizeof(magic), 1);
823838
}
824839
if (magic == data::SimpleDMatrix::kMagic) {
825-
DMatrix *dmat = new data::SimpleDMatrix(&is);
840+
DMatrix* dmat = new data::SimpleDMatrix(&is);
826841
if (!silent) {
827-
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_
828-
<< " matrix with " << dmat->Info().num_nonzero_
829-
<< " entries loaded from " << fname;
842+
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
843+
<< dmat->Info().num_nonzero_ << " entries loaded from " << fname;
830844
}
831845
return dmat;
832846
}
833847
}
834848
}
835849
return nullptr;
836850
}
851+
} // namespace
837852

838853
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
839854
auto need_split = false;
@@ -845,7 +860,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
845860
}
846861

847862
std::string fname, cache_file;
848-
size_t dlm_pos = uri.find('#');
863+
auto dlm_pos = uri.find('#');
849864
if (dlm_pos != std::string::npos) {
850865
cache_file = uri.substr(dlm_pos + 1, uri.length());
851866
fname = uri.substr(0, dlm_pos);
@@ -857,14 +872,11 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
857872
for (size_t i = 0; i < cache_shards.size(); ++i) {
858873
size_t pos = cache_shards[i].rfind('.');
859874
if (pos == std::string::npos) {
860-
os << cache_shards[i]
861-
<< ".r" << collective::GetRank()
862-
<< "-" << collective::GetWorldSize();
875+
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
876+
<< collective::GetWorldSize();
863877
} else {
864-
os << cache_shards[i].substr(0, pos)
865-
<< ".r" << collective::GetRank()
866-
<< "-" << collective::GetWorldSize()
867-
<< cache_shards[i].substr(pos, cache_shards[i].length());
878+
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
879+
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
868880
}
869881
if (i + 1 != cache_shards.size()) {
870882
os << ':';
@@ -895,12 +907,12 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
895907
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
896908
}
897909

898-
data::ValidateFileFormat(fname);
899-
DMatrix* dmat {nullptr};
910+
DMatrix* dmat{nullptr};
900911

901912
if (cache_file.empty()) {
902-
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
903-
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
913+
fname = data::ValidateFileFormat(fname);
914+
std::unique_ptr<dmlc::Parser<std::uint32_t>> parser(
915+
dmlc::Parser<std::uint32_t>::Create(fname.c_str(), partid, npart, "auto"));
904916
data::FileAdapter adapter(parser.get());
905917
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
906918
cache_file, data_split_mode);

src/data/file_iterator.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/**
2+
* Copyright 2021-2023, XGBoost contributors
3+
*/
4+
#include "file_iterator.h"
5+
6+
#include <xgboost/logging.h> // for LogCheck_EQ, LogCheck_LE, CHECK_EQ, CHECK_LE, LOG, LOG_...
7+
8+
#include <filesystem> // for weakly_canonical, path, u8path
9+
#include <map> // for map, operator==
10+
#include <ostream> // for operator<<, basic_ostream, istringstream
11+
#include <vector> // for vector
12+
13+
#include "../common/common.h" // for Split
14+
#include "xgboost/string_view.h" // for operator<<, StringView
15+
16+
namespace xgboost::data {
17+
std::string ValidateFileFormat(std::string const& uri) {
18+
std::vector<std::string> name_args_cache = common::Split(uri, '#');
19+
CHECK_LE(name_args_cache.size(), 2)
20+
<< "Only one `#` is allowed in file path for cachefile specification";
21+
22+
std::vector<std::string> name_args = common::Split(name_args_cache[0], '?');
23+
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
24+
CHECK_EQ(name_args.size(), 2) << msg;
25+
26+
std::map<std::string, std::string> args;
27+
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
28+
for (size_t i = 0; i < arg_list.size(); ++i) {
29+
std::istringstream is(arg_list[i]);
30+
std::pair<std::string, std::string> kv;
31+
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
32+
<< " for key in arg " << i + 1;
33+
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
34+
<< " for value in arg " << i + 1;
35+
args.insert(kv);
36+
}
37+
if (args.find("format") == args.cend()) {
38+
LOG(FATAL) << msg;
39+
}
40+
41+
auto path = common::Split(uri, '?')[0];
42+
43+
namespace fs = std::filesystem;
44+
name_args[0] = fs::weakly_canonical(fs::u8path(path)).string();
45+
if (name_args_cache.size() == 1) {
46+
return name_args[0] + "?" + name_args[1];
47+
} else {
48+
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
49+
}
50+
}
51+
} // namespace xgboost::data

src/data/file_iterator.h

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,20 @@
44
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
55
#define XGBOOST_DATA_FILE_ITERATOR_H_
66

7-
#include <map>
8-
#include <memory>
9-
#include <string>
10-
#include <utility>
11-
#include <vector>
12-
13-
#include "array_interface.h"
14-
#include "dmlc/data.h"
15-
#include "xgboost/c_api.h"
16-
#include "xgboost/json.h"
17-
#include "xgboost/linalg.h"
18-
19-
namespace xgboost {
20-
namespace data {
21-
inline void ValidateFileFormat(std::string const& uri) {
22-
std::vector<std::string> name_cache = common::Split(uri, '#');
23-
CHECK_LE(name_cache.size(), 2)
24-
<< "Only one `#` is allowed in file path for cachefile specification";
25-
26-
std::vector<std::string> name_args = common::Split(name_cache[0], '?');
27-
CHECK_LE(name_args.size(), 2) << "only one `?` is allowed in file path.";
28-
29-
StringView msg{"URI parameter `format` is required for loading text data: filename?format=csv"};
30-
CHECK_EQ(name_args.size(), 2) << msg;
31-
32-
std::map<std::string, std::string> args;
33-
std::vector<std::string> arg_list = common::Split(name_args[1], '&');
34-
for (size_t i = 0; i < arg_list.size(); ++i) {
35-
std::istringstream is(arg_list[i]);
36-
std::pair<std::string, std::string> kv;
37-
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
38-
<< " for key in arg " << i + 1;
39-
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
40-
<< " for value in arg " << i + 1;
41-
args.insert(kv);
42-
}
43-
if (args.find("format") == args.cend()) {
44-
LOG(FATAL) << msg;
45-
}
46-
}
7+
#include <algorithm> // for max_element
8+
#include <cstddef> // for size_t
9+
#include <cstdint> // for uint32_t
10+
#include <memory> // for unique_ptr
11+
#include <string> // for string
12+
#include <utility> // for move
13+
14+
#include "dmlc/data.h" // for RowBlock, Parser
15+
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
16+
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
17+
#include "xgboost/logging.h" // for CHECK
18+
19+
namespace xgboost::data {
20+
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
4721

4822
/**
4923
* An iterator for implementing external memory support with file inputs. Users of
@@ -72,8 +46,7 @@ class FileIterator {
7246

7347
public:
7448
FileIterator(std::string uri, unsigned part_index, unsigned num_parts)
75-
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts} {
76-
ValidateFileFormat(uri_);
49+
: uri_{ValidateFileFormat(std::move(uri))}, part_idx_{part_index}, n_parts_{num_parts} {
7750
XGProxyDMatrixCreate(&proxy_);
7851
}
7952
~FileIterator() {
@@ -132,6 +105,5 @@ inline int Next(DataIterHandle self) {
132105
return static_cast<FileIterator*>(self)->Next();
133106
}
134107
} // namespace fileiter
135-
} // namespace data
136-
} // namespace xgboost
108+
} // namespace xgboost::data
137109
#endif // XGBOOST_DATA_FILE_ITERATOR_H_

0 commit comments

Comments
 (0)