Skip to content

Commit 3ef1703

Browse files
authored
Allow using string view to find JSON value. (dmlc#8332)
- Allow comparison between string and string view. - Fix compiler warnings.
1 parent 2959510 commit 3ef1703

File tree

8 files changed

+109
-107
lines changed

8 files changed

+109
-107
lines changed

include/xgboost/json.h

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,29 +187,31 @@ using I32Array = JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
187187
using I64Array = JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
188188

189189
class JsonObject : public Value {
190-
std::map<std::string, Json> object_;
190+
public:
191+
using Map = std::map<std::string, Json, std::less<>>;
192+
193+
private:
194+
Map object_;
191195

192196
public:
193197
JsonObject() : Value(ValueKind::kObject) {}
194-
JsonObject(std::map<std::string, Json>&& object) noexcept; // NOLINT
198+
JsonObject(Map&& object) noexcept; // NOLINT
195199
JsonObject(JsonObject const& that) = delete;
196-
JsonObject(JsonObject && that) noexcept;
200+
JsonObject(JsonObject&& that) noexcept;
197201

198202
void Save(JsonWriter* writer) const override;
199203

200204
// silent the partial oveeridden warning
201205
Json& operator[](int ind) override { return Value::operator[](ind); }
202206
Json& operator[](std::string const& key) override { return object_[key]; }
203207

204-
std::map<std::string, Json> const& GetObject() && { return object_; }
205-
std::map<std::string, Json> const& GetObject() const & { return object_; }
206-
std::map<std::string, Json> & GetObject() & { return object_; }
208+
Map const& GetObject() && { return object_; }
209+
Map const& GetObject() const& { return object_; }
210+
Map& GetObject() & { return object_; }
207211

208212
bool operator==(Value const& rhs) const override;
209213

210-
static bool IsClassOf(Value const* value) {
211-
return value->Type() == ValueKind::kObject;
212-
}
214+
static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kObject; }
213215
~JsonObject() override = default;
214216
};
215217

@@ -559,16 +561,13 @@ std::vector<T> const& GetImpl(JsonTypedArray<T, kind> const& val) {
559561
}
560562

561563
// Object
562-
template <typename T,
563-
typename std::enable_if<
564-
std::is_same<T, JsonObject>::value>::type* = nullptr>
565-
std::map<std::string, Json>& GetImpl(T& val) { // NOLINT
564+
template <typename T, typename std::enable_if<std::is_same<T, JsonObject>::value>::type* = nullptr>
565+
JsonObject::Map& GetImpl(T& val) { // NOLINT
566566
return val.GetObject();
567567
}
568568
template <typename T,
569-
typename std::enable_if<
570-
std::is_same<T, JsonObject const>::value>::type* = nullptr>
571-
std::map<std::string, Json> const& GetImpl(T& val) { // NOLINT
569+
typename std::enable_if<std::is_same<T, JsonObject const>::value>::type* = nullptr>
570+
JsonObject::Map const& GetImpl(T& val) { // NOLINT
572571
return val.GetObject();
573572
}
574573
} // namespace detail

include/xgboost/string_view.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#ifndef XGBOOST_STRING_VIEW_H_
55
#define XGBOOST_STRING_VIEW_H_
66
#include <xgboost/logging.h>
7+
#include <xgboost/span.h>
78

89
#include <algorithm>
910
#include <iterator>
@@ -19,6 +20,7 @@ struct StringView {
1920
size_t size_{0};
2021

2122
public:
23+
using value_type = CharT; // NOLINT
2224
using iterator = const CharT*; // NOLINT
2325
using const_iterator = iterator; // NOLINT
2426
using reverse_iterator = std::reverse_iterator<const_iterator>; // NOLINT
@@ -77,5 +79,14 @@ inline bool operator==(StringView l, StringView r) {
7779
}
7880

7981
inline bool operator!=(StringView l, StringView r) { return !(l == r); }
82+
83+
inline bool operator<(StringView l, StringView r) {
84+
return common::Span<StringView::value_type const>{l.c_str(), l.size()} <
85+
common::Span<StringView::value_type const>{r.c_str(), r.size()};
86+
}
87+
88+
inline bool operator<(std::string const& l, StringView r) { return StringView{l} < r; }
89+
90+
inline bool operator<(StringView l, std::string const& r) { return l < StringView{r}; }
8091
} // namespace xgboost
8192
#endif // XGBOOST_STRING_VIEW_H_

src/c_api/c_api_utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ void TypeCheck(Json const &value, StringView name) {
258258
}
259259

260260
template <typename JT>
261-
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
261+
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
262262
auto const &obj = get<Object const>(in);
263263
auto it = obj.find(key);
264264
if (it == obj.cend() || IsA<Null>(it->second)) {
@@ -269,11 +269,11 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func)
269269
}
270270

271271
template <typename JT, typename T>
272-
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
272+
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
273273
auto const &obj = get<Object const>(in);
274274
auto it = obj.find(key);
275275
if (it != obj.cend() && !IsA<Null>(it->second)) {
276-
TypeCheck<JT>(it->second, StringView{key});
276+
TypeCheck<JT>(it->second, key);
277277
return get<std::remove_const_t<JT> const>(it->second);
278278
}
279279
return dft;

src/common/json.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
199199
std::swap(that.object_, this->object_);
200200
}
201201

202-
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
203-
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
202+
JsonObject::JsonObject(Map&& object) noexcept
203+
: Value(ValueKind::kObject), object_{std::forward<Map>(object)} {}
204204

205205
bool JsonObject::operator==(Value const& rhs) const {
206206
if (!IsA<JsonObject>(&rhs)) {
@@ -502,7 +502,7 @@ Json JsonReader::ParseArray() {
502502
Json JsonReader::ParseObject() {
503503
GetConsecutiveChar('{');
504504

505-
std::map<std::string, Json> data;
505+
Object::Map data;
506506
SkipSpaces();
507507
char ch = PeekNextChar();
508508

@@ -777,7 +777,7 @@ std::string UBJReader::DecodeStr() {
777777

778778
Json UBJReader::ParseObject() {
779779
auto marker = PeekNextChar();
780-
std::map<std::string, Json> results;
780+
Object::Map results;
781781

782782
while (marker != '}') {
783783
auto str = this->DecodeStr();

src/data/array_interface.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class ArrayInterfaceHandler {
9999
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
100100

101101
template <typename PtrType>
102-
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const &obj) {
102+
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
103103
auto data_it = obj.find("data");
104104
if (data_it == obj.cend()) {
105105
LOG(FATAL) << "Empty data passed in.";
@@ -109,7 +109,7 @@ class ArrayInterfaceHandler {
109109
return p_data;
110110
}
111111

112-
static void Validate(std::map<std::string, Json> const &array) {
112+
static void Validate(Object::Map const &array) {
113113
auto version_it = array.find("version");
114114
if (version_it == array.cend()) {
115115
LOG(FATAL) << "Missing `version' field for array interface";
@@ -136,7 +136,7 @@ class ArrayInterfaceHandler {
136136

137137
// Find null mask (validity mask) field
138138
// Mask object is also an array interface, but with different requirements.
139-
static size_t ExtractMask(std::map<std::string, Json> const &column,
139+
static size_t ExtractMask(Object::Map const &column,
140140
common::Span<RBitField8::value_type> *p_out) {
141141
auto &s_mask = *p_out;
142142
if (column.find("mask") != column.cend()) {
@@ -208,7 +208,7 @@ class ArrayInterfaceHandler {
208208
}
209209

210210
template <int32_t D>
211-
static void ExtractShape(std::map<std::string, Json> const &array, size_t (&out_shape)[D]) {
211+
static void ExtractShape(Object::Map const &array, size_t (&out_shape)[D]) {
212212
auto const &j_shape = get<Array const>(array.at("shape"));
213213
std::vector<size_t> shape_arr(j_shape.size(), 0);
214214
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
@@ -229,7 +229,7 @@ class ArrayInterfaceHandler {
229229
* \brief Extracts the optiona `strides' field and returns whether the array is c-contiguous.
230230
*/
231231
template <int32_t D>
232-
static bool ExtractStride(std::map<std::string, Json> const &array, size_t itemsize,
232+
static bool ExtractStride(Object::Map const &array, size_t itemsize,
233233
size_t (&shape)[D], size_t (&stride)[D]) {
234234
auto strides_it = array.find("strides");
235235
// No stride is provided
@@ -272,7 +272,7 @@ class ArrayInterfaceHandler {
272272
return std::equal(stride_tmp, stride_tmp + D, stride);
273273
}
274274

275-
static void *ExtractData(std::map<std::string, Json> const &array, size_t size) {
275+
static void *ExtractData(Object::Map const &array, size_t size) {
276276
Validate(array);
277277
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
278278
if (!p_data) {
@@ -378,7 +378,7 @@ class ArrayInterface {
378378
* to a vector of size n_samples. For for inputs like weights, this should be a 1
379379
* dimension column vector even though user might provide a matrix.
380380
*/
381-
void Initialize(std::map<std::string, Json> const &array) {
381+
void Initialize(Object::Map const &array) {
382382
ArrayInterfaceHandler::Validate(array);
383383

384384
auto typestr = get<String const>(array.at("typestr"));
@@ -413,7 +413,7 @@ class ArrayInterface {
413413

414414
public:
415415
ArrayInterface() = default;
416-
explicit ArrayInterface(std::map<std::string, Json> const &array) { this->Initialize(array); }
416+
explicit ArrayInterface(Object::Map const &array) { this->Initialize(array); }
417417

418418
explicit ArrayInterface(Json const &array) {
419419
if (IsA<Object>(array)) {

src/metric/auc.cu

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ struct DeviceAUCCache {
6060
};
6161

6262
template <bool is_multi>
63-
void InitCacheOnce(common::Span<float const> predts, int32_t device,
64-
std::shared_ptr<DeviceAUCCache>* p_cache) {
63+
void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCache> *p_cache) {
6564
auto& cache = *p_cache;
6665
if (!cache) {
6766
cache.reset(new DeviceAUCCache);
@@ -167,7 +166,7 @@ std::tuple<double, double, double>
167166
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
168167
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
169168
auto &cache = *p_cache;
170-
InitCacheOnce<false>(predts, device, p_cache);
169+
InitCacheOnce<false>(predts, p_cache);
171170

172171
/**
173172
* Create sorted index for each class
@@ -196,8 +195,7 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
196195
}
197196

198197
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
199-
common::Span<double> tp, common::Span<double> auc,
200-
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
198+
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
201199
dh::XGBDeviceAllocator<char> alloc;
202200
if (collective::IsDistributed()) {
203201
int32_t device = dh::CurrentDevice();
@@ -330,7 +328,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
330328
auto local_area = d_results.subspan(0, n_classes);
331329
auto tp = d_results.subspan(2 * n_classes, n_classes);
332330
auto auc = d_results.subspan(3 * n_classes, n_classes);
333-
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
331+
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
334332
}
335333

336334
/**
@@ -434,7 +432,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
434432
tp[c] = 1.0f;
435433
}
436434
});
437-
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
435+
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
438436
}
439437

440438
void MultiClassSortedIdx(common::Span<float const> predts,
@@ -458,7 +456,7 @@ double GPUMultiClassROCAUC(common::Span<float const> predts,
458456
std::shared_ptr<DeviceAUCCache> *p_cache,
459457
size_t n_classes) {
460458
auto& cache = *p_cache;
461-
InitCacheOnce<true>(predts, device, p_cache);
459+
InitCacheOnce<true>(predts, p_cache);
462460

463461
/**
464462
* Create sorted index for each class
@@ -486,7 +484,7 @@ std::pair<double, uint32_t>
486484
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
487485
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
488486
auto& cache = *p_cache;
489-
InitCacheOnce<false>(predts, device, p_cache);
487+
InitCacheOnce<false>(predts, p_cache);
490488

491489
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
492490
dh::XGBCachingDeviceAllocator<char> alloc;
@@ -606,7 +604,7 @@ std::tuple<double, double, double>
606604
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
607605
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
608606
auto& cache = *p_cache;
609-
InitCacheOnce<false>(predts, device, p_cache);
607+
InitCacheOnce<false>(predts, p_cache);
610608

611609
/**
612610
* Create sorted index for each class
@@ -647,7 +645,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
647645
std::shared_ptr<DeviceAUCCache> *p_cache,
648646
size_t n_classes) {
649647
auto& cache = *p_cache;
650-
InitCacheOnce<true>(predts, device, p_cache);
648+
InitCacheOnce<true>(predts, p_cache);
651649

652650
/**
653651
* Create sorted index for each class
@@ -827,7 +825,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
827825
}
828826

829827
auto &cache = *p_cache;
830-
InitCacheOnce<false>(predts, device, p_cache);
828+
InitCacheOnce<false>(predts, p_cache);
831829

832830
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
833831
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());

tests/cpp/common/test_json.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,7 @@ TEST(Json, WrongCasts) {
499499
ASSERT_ANY_THROW(get<Number>(json));
500500
}
501501
{
502-
Json json = Json{ Object{std::map<std::string, Json>{
503-
{"key", Json{String{"value"}}}} } };
502+
Json json = Json{Object{{{"key", Json{String{"value"}}}}}};
504503
ASSERT_ANY_THROW(get<Number>(json));
505504
}
506505
}

0 commit comments

Comments
 (0)