Skip to content

Commit 2a3ba30

Browse files
authored
Implement the container for categories. (dmlc#11297)
1 parent 4a28128 commit 2a3ba30

17 files changed

+1052
-205
lines changed

include/xgboost/json.h

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2024, XGBoost Contributors
2+
* Copyright 2019-2025, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_JSON_H_
55
#define XGBOOST_JSON_H_
@@ -25,29 +25,28 @@ class JsonWriter;
2525
class Value {
2626
private:
2727
mutable class IntrusivePtrCell ref_;
28-
friend IntrusivePtrCell &
29-
IntrusivePtrRefCount(xgboost::Value const *t) noexcept {
28+
friend IntrusivePtrCell& IntrusivePtrRefCount(xgboost::Value const* t) noexcept {
3029
return t->ref_;
3130
}
3231

3332
public:
3433
/*!\brief Simplified implementation of LLVM RTTI. */
35-
enum class ValueKind {
36-
kString,
37-
kNumber,
38-
kInteger,
39-
kObject, // std::map
40-
kArray, // std::vector
41-
kBoolean,
42-
kNull,
34+
enum class ValueKind : std::int64_t {
35+
kString = 0,
36+
kNumber = 1,
37+
kInteger = 2,
38+
kObject = 3, // std::map
39+
kArray = 4, // std::vector
40+
kBoolean = 5,
41+
kNull = 6,
4342
// typed array for ubjson
44-
kF32Array,
45-
kF64Array,
46-
kI8Array,
47-
kU8Array,
48-
kI16Array,
49-
kI32Array,
50-
kI64Array
43+
kF32Array = 7,
44+
kF64Array = 8,
45+
kI8Array = 9,
46+
kU8Array = 10,
47+
kI16Array = 11,
48+
kI32Array = 12,
49+
kI64Array = 13
5150
};
5251

5352
explicit Value(ValueKind _kind) : kind_{_kind} {}
@@ -152,7 +151,7 @@ class JsonTypedArray : public Value {
152151
std::vector<T> vec_;
153152

154153
public:
155-
using Type = T;
154+
using value_type = T; // NOLINT
156155

157156
JsonTypedArray() : Value(kind) {}
158157
explicit JsonTypedArray(std::size_t n) : Value(kind) { vec_.resize(n); }

include/xgboost/json_io.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class UBJReader : public JsonReader {
223223
auto ParseTypedArray(std::int64_t n) {
224224
TypedArray results{static_cast<size_t>(n)};
225225
for (int64_t i = 0; i < n; ++i) {
226-
auto v = this->ReadPrimitive<typename TypedArray::Type>();
226+
auto v = this->ReadPrimitive<typename TypedArray::value_type>();
227227
results.Set(i, v);
228228
}
229229
return Json{std::move(results)};

src/data/cat_container.cc

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/**
2+
* Copyright 2025, XGBoost Contributors
3+
*/
4+
#include "cat_container.h"
5+
6+
#include <algorithm> // for copy
7+
#include <cstddef> // for size_t
8+
#include <memory> // for make_unique
9+
#include <utility> // for move
10+
#include <vector> // for vector
11+
12+
#include "../encoder/types.h" // for Overloaded
13+
#include "xgboost/json.h" // for Json
14+
15+
namespace xgboost {
16+
/**
 * @brief Construct the container by copying a host view of categorical columns.
 *
 * The total number of categories, the feature segments and every column in
 * `df` are copied into storage owned by this object. The sorted index is left
 * empty and the CPU implementation is finalized afterwards.
 *
 * @param df Host view of the categories to copy from.
 */
CatContainer::CatContainer(enc::HostColumnsView const& df) : CatContainer{} {
  this->n_total_cats_ = df.n_total_cats;

  auto const n_segments = df.feature_segments.size();
  this->feature_segments_.Resize(n_segments);
  auto& h_segments = this->feature_segments_.HostVector();
  std::copy_n(df.feature_segments.data(), n_segments, h_segments.begin());

  auto& columns = this->cpu_impl_->columns;
  // One visitor shared by all columns; each overload appends a new column and
  // fills it with a copy of the viewed data.
  auto copy_column = enc::Overloaded{
      // String column: copy both the offsets and the character buffer.
      [&columns](enc::CatStrArrayView str) {
        using Storage = typename cpu_impl::ViewToStorageImpl<enc::CatStrArrayView>::Type;
        columns.emplace_back();
        auto& dst = columns.back().emplace<Storage>();
        dst.offsets.resize(str.offsets.size());
        dst.values.resize(str.values.size());
        std::copy(str.offsets.data(), str.offsets.data() + str.offsets.size(),
                  dst.offsets.data());
        std::copy(str.values.data(), str.values.data() + str.values.size(), dst.values.data());
      },
      // Any other column view: copy the elements into the matching storage type.
      [&columns](auto&& values) {
        using Storage =
            typename cpu_impl::ViewToStorageImpl<std::decay_t<decltype(values)>>::Type;
        columns.emplace_back();
        auto& dst = columns.back().emplace<Storage>();
        dst.resize(values.size());
        std::copy(values.data(), values.data() + values.size(), dst.data());
      }};
  for (auto const& col : df.columns) {
    std::visit(copy_column, col);
  }

  this->sorted_idx_.Resize(0);
  this->cpu_impl_->Finalize();

  // Sanity checks on the freshly copied state.
  CHECK(!this->DeviceCanRead());
  CHECK(this->HostCanRead());
  CHECK_EQ(this->n_total_cats_, df.feature_segments.back());
  CHECK_GE(this->n_total_cats_, 0) << "Too many categories.";
}
55+
56+
namespace {
57+
template <typename T>
58+
struct PrimToUbj;
59+
60+
template <>
61+
struct PrimToUbj<std::int8_t> {
62+
using Type = I8Array;
63+
};
64+
template <>
65+
struct PrimToUbj<std::int16_t> {
66+
using Type = I16Array;
67+
};
68+
template <>
69+
struct PrimToUbj<std::int32_t> {
70+
using Type = I32Array;
71+
};
72+
template <>
73+
struct PrimToUbj<std::int64_t> {
74+
using Type = I64Array;
75+
};
76+
template <>
77+
struct PrimToUbj<float> {
78+
using Type = F32Array;
79+
};
80+
template <>
81+
struct PrimToUbj<double> {
82+
using Type = F64Array;
83+
};
84+
} // anonymous namespace
85+
86+
/**
 * @brief Serialize the categories into a JSON object.
 *
 * The output object has three fields:
 *  - "enc": an array with one object per column. A string column is stored as
 *    {"offsets": I32Array, "values": I8Array}; any other column is stored as
 *    {"values": <typed array>, "type": <integer value of the array's ValueKind>}.
 *  - "feature_segments": I32Array copy of the feature segments.
 *  - "sorted_idx": I32Array copy of the sorted index.
 *
 * @param p_out Destination; its previous content is replaced.
 */
void CatContainer::Save(Json* p_out) const {
  // NOTE(review): the view itself is unused; presumably this call is here to make
  // the host copy readable before `cpu_impl_` is accessed - confirm.
  [[maybe_unused]] auto _ = this->HostView();
  auto& out = *p_out;

  auto const& columns = this->cpu_impl_->columns;
  std::vector<Json> arr(columns.size());
  for (std::size_t fidx = 0, n_features = columns.size(); fidx < n_features; ++fidx) {
    auto& f_out = arr[fidx];

    std::visit(enc::Overloaded{
                   [&f_out](cpu_impl::CatStrArray const& str) {
                     f_out = Object{};
                     I32Array joffsets{str.offsets.size()};
                     std::copy(str.offsets.cbegin(), str.offsets.cend(),
                               joffsets.GetArray().begin());
                     f_out["offsets"] = std::move(joffsets);

                     I8Array jnames{str.values.size()};  // fixme: uint8
                     std::copy(str.values.cbegin(), str.values.cend(),
                               jnames.GetArray().begin());
                     f_out["values"] = std::move(jnames);
                   },
                   [&f_out](auto&& values) {
                     using T =
                         std::remove_cv_t<typename std::decay_t<decltype(values)>::value_type>;
                     using JT = typename PrimToUbj<T>::Type;
                     JT array{values.size()};
                     std::copy_n(values.data(), values.size(), array.GetArray().begin());

                     Object encoded{};
                     // Query the kind before `array` is moved from.
                     encoded["type"] = static_cast<std::int64_t>(array.Type());
                     encoded["values"] = std::move(array);

                     f_out = std::move(encoded);
                   }},
               columns[fidx]);
  }

  auto jf_segments = I32Array{this->feature_segments_.Size()};
  auto const& hf_segments = this->feature_segments_.ConstHostVector();
  std::copy(hf_segments.cbegin(), hf_segments.cend(), jf_segments.GetArray().begin());

  auto jsorted_index = I32Array{this->sorted_idx_.Size()};
  auto const& h_sorted_idx = this->sorted_idx_.ConstHostVector();
  std::copy_n(h_sorted_idx.cbegin(), h_sorted_idx.size(), jsorted_index.GetArray().begin());

  out = Object{};
  out["sorted_idx"] = std::move(jsorted_index);
  out["feature_segments"] = std::move(jf_segments);
  // `arr` is not used afterwards, so hand the per-column objects over without copying.
  out["enc"] = std::move(arr);
}
138+
139+
namespace {
140+
// Dispatch method for JSON and UBJSON
141+
template <typename U, typename Vec>
142+
void LoadJson(Json jvalues, Vec* p_out) {
143+
if (IsA<Array>(jvalues)) {
144+
auto const& jarray = get<Array const>(jvalues);
145+
std::vector<U> buf(jarray.size());
146+
for (std::size_t i = 0, n = jarray.size(); i < n; ++i) {
147+
buf[i] = static_cast<U>(get<Integer const>(jarray[i]));
148+
}
149+
*p_out = std::move(buf);
150+
return;
151+
}
152+
auto const& values = get<std::add_const_t<typename PrimToUbj<U>::Type>>(jvalues);
153+
*p_out = std::move(values);
154+
}
155+
} // namespace
156+
157+
/**
 * @brief Load the categories from a JSON object.
 *
 * Reads one column per entry of "enc": an entry with an "offsets" key is loaded
 * as a string column ("offsets" + "values"); any other entry is loaded as a
 * numeric column whose element type is selected by its integer "type" field.
 * Afterwards "feature_segments" and "sorted_idx" are loaded, the total number
 * of categories is taken from the last feature segment, and the CPU
 * implementation is finalized.
 *
 * NOTE(review): loaded columns are appended to the existing ones; presumably
 * this is only called on an empty container - confirm against callers.
 *
 * @param in JSON object with the "enc", "feature_segments" and "sorted_idx" fields.
 */
void CatContainer::Load(Json const& in) {
  // Bind by reference; a plain `auto` would copy the whole array of columns.
  auto const& array = get<Array const>(in["enc"]);
  auto const n_features = array.size();

  auto& columns = this->cpu_impl_->columns;
  for (std::size_t fidx = 0; fidx < n_features; ++fidx) {
    auto const& column = get<Object const>(array[fidx]);
    if (auto it = column.find("offsets"); it != column.cend()) {
      // str
      cpu_impl::CatStrArray str{};
      LoadJson<std::int32_t>(it->second, &str.offsets);
      LoadJson<enc::CatCharT>(column.at("values"), &str.values);

      columns.emplace_back(std::move(str));
    } else {
      // numeric
      auto type = get<Integer const>(column.at("type"));
      using T = Value::ValueKind;
      auto const& jvalues = column.at("values");
      columns.emplace_back();
      switch (static_cast<Value::ValueKind>(type)) {
        case T::kI8Array: {
          LoadJson<std::int8_t>(jvalues, &columns.back());
          break;
        }
        case T::kI16Array: {
          LoadJson<std::int16_t>(jvalues, &columns.back());
          break;
        }
        case T::kI32Array: {
          LoadJson<std::int32_t>(jvalues, &columns.back());
          break;
        }
        case T::kI64Array: {
          LoadJson<std::int64_t>(jvalues, &columns.back());
          break;
        }
        case T::kF32Array: {
          LoadJson<float>(jvalues, &columns.back());
          break;
        }
        case T::kF64Array: {
          LoadJson<double>(jvalues, &columns.back());
          break;
        }
        default: {
          LOG(FATAL) << "Invalid type.";
        }
      }
    }
  }

  auto& hf_segments = this->feature_segments_.HostVector();
  LoadJson<std::int32_t>(in["feature_segments"], &hf_segments);
  CHECK(!hf_segments.empty());
  this->n_total_cats_ = hf_segments.back();

  auto& h_sorted_idx = this->sorted_idx_.HostVector();
  LoadJson<std::int32_t>(in["sorted_idx"], &h_sorted_idx);

  this->cpu_impl_->Finalize();
}
220+
221+
#if !defined(XGBOOST_USE_CUDA)
222+
CatContainer::CatContainer() : cpu_impl_{std::make_unique<cpu_impl::CatContainerImpl>()} {}
223+
224+
CatContainer::~CatContainer() = default;
225+
226+
void CatContainer::Copy(Context const*, CatContainer const& that) { this->CopyCommon(that); }
227+
228+
[[nodiscard]] enc::HostColumnsView CatContainer::HostView() const { return this->HostViewImpl(); }
229+
230+
void CatContainer::Sort(Context const* ctx) {
231+
CHECK(ctx->IsCPU());
232+
auto view = this->HostView();
233+
this->sorted_idx_.HostVector().resize(view.n_total_cats);
234+
enc::SortNames(enc::Policy<EncErrorPolicy>{}, view, this->sorted_idx_.HostSpan());
235+
}
236+
237+
[[nodiscard]] bool CatContainer::DeviceCanRead() const { return false; }
238+
#endif // !defined(XGBOOST_USE_CUDA)
239+
} // namespace xgboost

0 commit comments

Comments
 (0)