Skip to content

Commit 05d7000

Browse files
authored
Handle special characters in JSON model dump. (dmlc#9474)
1 parent f03463c commit 05d7000

File tree

7 files changed

+129
-105
lines changed

7 files changed

+129
-105
lines changed

src/common/common.cc

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
/*!
2-
* Copyright 2015-2019 by Contributors
3-
* \file common.cc
4-
* \brief Enable all kinds of global variables in common.
1+
/**
2+
* Copyright 2015-2023 by Contributors
53
*/
6-
#include <dmlc/thread_local.h>
7-
#include <xgboost/logging.h>
8-
94
#include "common.h"
10-
#include "./random.h"
115

12-
namespace xgboost {
13-
namespace common {
6+
#include <dmlc/thread_local.h> // for ThreadLocalStore
7+
8+
#include <cstdint> // for uint8_t
9+
#include <cstdio> // for snprintf, size_t
10+
#include <string> // for string
11+
12+
#include "./random.h" // for GlobalRandomEngine, GlobalRandom
13+
14+
namespace xgboost::common {
1415
/*! \brief thread local entry for random. */
1516
struct RandomThreadLocalEntry {
1617
/*! \brief the random engine instance. */
@@ -19,15 +20,43 @@ struct RandomThreadLocalEntry {
1920

2021
using RandomThreadLocalStore = dmlc::ThreadLocalStore<RandomThreadLocalEntry>;
2122

22-
GlobalRandomEngine& GlobalRandom() {
23-
return RandomThreadLocalStore::Get()->engine;
23+
/*! \brief Access the random engine held in the calling thread's thread-local entry. */
GlobalRandomEngine &GlobalRandom() {
  auto *entry = RandomThreadLocalStore::Get();
  return entry->engine;
}
24+
25+
/*!
 * \brief Append a JSON-escaped copy of a string to an output buffer.
 *
 * Double quotes, backslashes and ASCII control characters (<= 0x1f) are replaced by
 * their JSON escape sequences; every other byte, including the bytes of multi-byte
 * UTF-8 sequences, is copied unchanged.  A backslash that is immediately followed by
 * 'u' is copied as a single backslash so that an already escaped "\uXXXX" sequence
 * in the input is not escaped a second time.
 *
 * \param string   Input text.
 * \param p_buffer Output buffer.  Escaped text is appended; existing content is kept.
 *                 No surrounding quotes are added.
 */
void EscapeU8(std::string const &string, std::string *p_buffer) {
  auto &buffer = *p_buffer;
  auto const n_chars = string.size();
  for (std::size_t i = 0; i < n_chars; ++i) {
    auto const ch = string[i];
    switch (ch) {
      case '\\': {
        // Look ahead only when a next character really exists; a trailing backslash
        // is an ordinary backslash and must be escaped.
        bool const starts_unicode_escape = i + 1 < n_chars && string[i + 1] == 'u';
        buffer += starts_unicode_escape ? "\\" : "\\\\";
        break;
      }
      case '"':
        buffer += "\\\"";
        break;
      case '\b':
        buffer += "\\b";
        break;
      case '\f':
        buffer += "\\f";
        break;
      case '\n':
        buffer += "\\n";
        break;
      case '\r':
        buffer += "\\r";
        break;
      case '\t':
        buffer += "\\t";
        break;
      default: {
        auto const code = static_cast<std::uint8_t>(ch);
        if (code <= 0x1f) {
          // Remaining control characters have no short form: emit \u00XX.
          // 6 characters + terminating NUL fit in the 8 byte scratch buffer.
          char buf[8];
          // %x expects an unsigned int, so pass the code point as one explicitly.
          std::snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned>(code));
          buffer += buf;
        } else {
          buffer += ch;
        }
        break;
      }
    }
  }
}
2557

2658
#if !defined(XGBOOST_USE_CUDA)
27-
int AllVisibleGPUs() {
28-
return 0;
29-
}
59+
/*! \brief Stub compiled when XGBOOST_USE_CUDA is not defined: always reports zero GPUs. */
int AllVisibleGPUs() {
  constexpr int kNoDevices = 0;
  return kNoDevices;
}
3060
#endif // !defined(XGBOOST_USE_CUDA)
3161

32-
} // namespace common
33-
} // namespace xgboost
62+
} // namespace xgboost::common

src/common/common.h

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,19 @@
66
#ifndef XGBOOST_COMMON_COMMON_H_
77
#define XGBOOST_COMMON_COMMON_H_
88

9-
#include <xgboost/base.h>
10-
#include <xgboost/logging.h>
11-
#include <xgboost/span.h>
12-
13-
#include <algorithm>
14-
#include <exception>
15-
#include <functional>
16-
#include <limits>
17-
#include <numeric>
18-
#include <sstream>
19-
#include <string>
20-
#include <type_traits>
21-
#include <utility>
22-
#include <vector>
9+
#include <algorithm> // for max
10+
#include <array> // for array
11+
#include <cmath> // for ceil
12+
#include <cstddef> // for size_t
13+
#include <cstdint> // for int32_t, int64_t
14+
#include <sstream> // for basic_istream, operator<<, istringstream
15+
#include <string> // for string, basic_string, getline, char_traits
16+
#include <tuple> // for make_tuple
17+
#include <utility> // for forward, index_sequence, make_index_sequence
18+
#include <vector> // for vector
19+
20+
#include "xgboost/base.h" // for XGBOOST_DEVICE
21+
#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal
2322

2423
#if defined(__CUDACC__)
2524
#include <thrust/system/cuda/error.h>
@@ -52,8 +51,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
5251
#endif // defined(__CUDACC__)
5352
} // namespace dh
5453

55-
namespace xgboost {
56-
namespace common {
54+
namespace xgboost::common {
5755
/*!
5856
* \brief Split a string by delimiter
5957
* \param s String to be split.
@@ -69,19 +67,13 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
6967
return ret;
7068
}
7169

70+
/*!
 * \brief Append a JSON-escaped copy of a string to a buffer.
 *
 * Escapes double quotes, backslashes and ASCII control characters (<= 0x1f); all other
 * bytes are copied unchanged.  A backslash directly followed by 'u' is kept as a
 * single backslash.
 *
 * \param string   Input text.
 * \param p_buffer Output buffer; the escaped text is appended to its current content.
 */
void EscapeU8(std::string const &string, std::string *p_buffer);
71+
7272
template <typename T>
7373
XGBOOST_DEVICE T Max(T a, T b) {
7474
return a < b ? b : a;
7575
}
7676

77-
// simple routine to convert any data to string
78-
template<typename T>
79-
inline std::string ToString(const T& data) {
80-
std::ostringstream os;
81-
os << data;
82-
return os.str();
83-
}
84-
8577
template <typename T1, typename T2>
8678
XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
8779
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
@@ -195,6 +187,5 @@ template <typename Indexable>
195187
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
196188
return indptr[group + 1] - 1;
197189
}
198-
} // namespace common
199-
} // namespace xgboost
190+
} // namespace xgboost::common
200191
#endif // XGBOOST_COMMON_COMMON_H_

src/common/json.cc

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
1-
/*!
2-
* Copyright (c) by Contributors 2019-2022
1+
/**
2+
* Copyright 2019-2023, XGBoost Contributors
33
*/
44
#include "xgboost/json.h"
55

6-
#include <dmlc/endian.h>
7-
8-
#include <cctype>
9-
#include <cmath>
10-
#include <cstddef>
11-
#include <iterator>
12-
#include <limits>
13-
#include <sstream>
14-
15-
#include "./math.h"
16-
#include "charconv.h"
17-
#include "xgboost/base.h"
18-
#include "xgboost/json_io.h"
19-
#include "xgboost/logging.h"
20-
#include "xgboost/string_view.h"
6+
#include <array> // for array
7+
#include <cctype> // for isdigit
8+
#include <cmath> // for isinf, isnan
9+
#include <cstdio> // for EOF
10+
#include <cstdlib> // for size_t, strtof
11+
#include <cstring> // for memcpy
12+
#include <initializer_list> // for initializer_list
13+
#include <iterator> // for distance
14+
#include <limits> // for numeric_limits
15+
#include <memory> // for allocator
16+
#include <sstream> // for operator<<, basic_ostream, operator&, ios, stringstream
17+
#include <system_error> // for errc
18+
19+
#include "./math.h" // for CheckNAN
20+
#include "charconv.h" // for to_chars, NumericLimits, from_chars, to_chars_result
21+
#include "common.h" // for EscapeU8
22+
#include "xgboost/base.h" // for XGBOOST_EXPECT
23+
#include "xgboost/intrusive_ptr.h" // for IntrusivePtr
24+
#include "xgboost/json_io.h" // for JsonReader, UBJReader, UBJWriter, JsonWriter, ToBigEn...
25+
#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal, LogCheck_NE, CHECK
26+
#include "xgboost/string_view.h" // for StringView, operator<<
2127

2228
namespace xgboost {
2329

@@ -57,12 +63,12 @@ void JsonWriter::Visit(JsonObject const* obj) {
5763
}
5864

5965
void JsonWriter::Visit(JsonNumber const* num) {
60-
char number[NumericLimits<float>::kToCharsSize];
61-
auto res = to_chars(number, number + sizeof(number), num->GetNumber());
66+
std::array<char, NumericLimits<float>::kToCharsSize> number;
67+
auto res = to_chars(number.data(), number.data() + number.size(), num->GetNumber());
6268
auto end = res.ptr;
6369
auto ori_size = stream_->size();
64-
stream_->resize(stream_->size() + end - number);
65-
std::memcpy(stream_->data() + ori_size, number, end - number);
70+
stream_->resize(stream_->size() + end - number.data());
71+
std::memcpy(stream_->data() + ori_size, number.data(), end - number.data());
6672
}
6773

6874
void JsonWriter::Visit(JsonInteger const* num) {
@@ -88,43 +94,15 @@ void JsonWriter::Visit(JsonNull const* ) {
8894
}
8995

9096
void JsonWriter::Visit(JsonString const* str) {
91-
std::string buffer;
92-
buffer += '"';
93-
auto const& string = str->GetString();
94-
for (size_t i = 0; i < string.length(); i++) {
95-
const char ch = string[i];
96-
if (ch == '\\') {
97-
if (i < string.size() && string[i+1] == 'u') {
98-
buffer += "\\";
99-
} else {
100-
buffer += "\\\\";
101-
}
102-
} else if (ch == '"') {
103-
buffer += "\\\"";
104-
} else if (ch == '\b') {
105-
buffer += "\\b";
106-
} else if (ch == '\f') {
107-
buffer += "\\f";
108-
} else if (ch == '\n') {
109-
buffer += "\\n";
110-
} else if (ch == '\r') {
111-
buffer += "\\r";
112-
} else if (ch == '\t') {
113-
buffer += "\\t";
114-
} else if (static_cast<uint8_t>(ch) <= 0x1f) {
115-
// Unit separator
116-
char buf[8];
117-
snprintf(buf, sizeof buf, "\\u%04x", ch);
118-
buffer += buf;
119-
} else {
120-
buffer += ch;
121-
}
122-
}
123-
buffer += '"';
97+
std::string buffer;
98+
buffer += '"';
99+
auto const& string = str->GetString();
100+
common::EscapeU8(string, &buffer);
101+
buffer += '"';
124102

125-
auto s = stream_->size();
126-
stream_->resize(s + buffer.size());
127-
std::memcpy(stream_->data() + s, buffer.data(), buffer.size());
103+
auto s = stream_->size();
104+
stream_->resize(s + buffer.size());
105+
std::memcpy(stream_->data() + s, buffer.data(), buffer.size());
128106
}
129107

130108
void JsonWriter::Visit(JsonBoolean const* boolean) {

src/common/numeric.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <cstddef> // for size_t
1111
#include <cstdint> // for int32_t
1212
#include <iterator> // for iterator_traits
13+
#include <numeric> // for accumulate
1314
#include <vector>
1415

1516
#include "common.h" // AssertGPUSupport

src/learner.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ class LearnerConfiguration : public Learner {
797797
bool has_nc {cfg_.find("num_class") != cfg_.cend()};
798798
// Inject num_class into configuration.
799799
// FIXME(jiamingy): Remove the duplicated parameter in softmax
800-
cfg_["num_class"] = common::ToString(mparam_.num_class);
800+
cfg_["num_class"] = std::to_string(mparam_.num_class);
801801
auto& args = *p_args;
802802
args = {cfg_.cbegin(), cfg_.cend()}; // renew
803803
obj_->Configure(args);
@@ -1076,7 +1076,7 @@ class LearnerIO : public LearnerConfiguration {
10761076
mparam_.major_version = std::get<0>(Version::Self());
10771077
mparam_.minor_version = std::get<1>(Version::Self());
10781078

1079-
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
1079+
cfg_["num_feature"] = std::to_string(mparam_.num_feature);
10801080

10811081
auto n = tparam_.__DICT__();
10821082
cfg_.insert(n.cbegin(), n.cend());

src/tree/tree_model.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,11 +398,14 @@ class JsonGenerator : public TreeGenerator {
398398
static std::string const kIndicatorTemplate =
399399
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
400400
auto split_index = tree[nid].SplitIndex();
401+
auto fname = fmap_.Name(split_index);
402+
std::string qfname; // quoted
403+
common::EscapeU8(fname, &qfname);
401404
auto result = SuperT::Match(
402405
kIndicatorTemplate,
403406
{{"{nid}", std::to_string(nid)},
404407
{"{depth}", std::to_string(depth)},
405-
{"{fname}", fmap_.Name(split_index)},
408+
{"{fname}", qfname},
406409
{"{yes}", std::to_string(nyes)},
407410
{"{no}", std::to_string(tree[nid].DefaultChild())}});
408411
return result;
@@ -430,12 +433,14 @@ class JsonGenerator : public TreeGenerator {
430433
std::string const &template_str, std::string cond,
431434
uint32_t depth) const {
432435
auto split_index = tree[nid].SplitIndex();
436+
auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index);
437+
std::string qfname; // quoted
438+
common::EscapeU8(fname, &qfname);
433439
std::string const result = SuperT::Match(
434440
template_str,
435441
{{"{nid}", std::to_string(nid)},
436442
{"{depth}", std::to_string(depth)},
437-
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
438-
std::to_string(split_index)},
443+
{"{fname}", qfname},
439444
{"{cond}", cond},
440445
{"{left}", std::to_string(tree[nid].LeftChild())},
441446
{"{right}", std::to_string(tree[nid].RightChild())},

tests/python/test_basic_models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,26 @@ def validate_model(parameters):
439439
'objective': 'multi:softmax'}
440440
validate_model(parameters)
441441

442+
def test_special_model_dump_characters(self):
    """JSON model dump must stay parseable when feature names need escaping.

    Trains a small model whose feature names contain quotes, a tab and a
    newline, then checks that every dumped tree is valid JSON and that every
    "split" entry round-trips to one of the original feature names.
    """
    params = {"objective": "reg:squarederror", "max_depth": 3}
    feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
    # The third return value is not needed by this test.
    X, y, _ = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
    Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
    booster = xgb.train(params, Xy, num_boost_round=3)
    json_dump = booster.get_dump(dump_format="json")
    assert len(json_dump) == 3

    def validate(obj) -> None:
        # Walk the whole decoded document.  Nested nodes may be stored in
        # lists (presumably the "children" arrays of the dump -- confirm
        # against the dump format), so lists are traversed as well as dicts;
        # otherwise only the top-level split would be checked.
        if isinstance(obj, dict):
            for k, v in obj.items():
                if k == "split":
                    assert v in feature_names
                else:
                    validate(v)
        elif isinstance(obj, list):
            for item in obj:
                validate(item)

    for j_tree in json_dump:
        loaded = json.loads(j_tree)
        validate(loaded)
461+
442462
def test_categorical_model_io(self):
443463
X, y = tm.make_categorical(256, 16, 71, False)
444464
Xy = xgb.DMatrix(X, y, enable_categorical=True)

0 commit comments

Comments
 (0)