Skip to content

Commit db14e3f

Browse files
authored
Support null value in CUDA array interface. (dmlc#8486) (dmlc#8499)
1 parent 9372370 commit db14e3f

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

src/data/array_interface.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
101101
template <typename PtrType>
102102
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
103103
auto data_it = obj.find("data");
104-
if (data_it == obj.cend()) {
104+
if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
105105
LOG(FATAL) << "Empty data passed in.";
106106
}
107107
auto p_data = reinterpret_cast<PtrType>(
@@ -111,25 +111,27 @@ class ArrayInterfaceHandler {
111111

112112
static void Validate(Object::Map const &array) {
113113
auto version_it = array.find("version");
114-
if (version_it == array.cend()) {
114+
if (version_it == array.cend() || IsA<Null>(version_it->second)) {
115115
LOG(FATAL) << "Missing `version' field for array interface";
116116
}
117117
if (get<Integer const>(version_it->second) > 3) {
118118
LOG(FATAL) << ArrayInterfaceErrors::Version();
119119
}
120120

121121
auto typestr_it = array.find("typestr");
122-
if (typestr_it == array.cend()) {
122+
if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
123123
LOG(FATAL) << "Missing `typestr' field for array interface";
124124
}
125125

126126
auto typestr = get<String const>(typestr_it->second);
127127
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
128128

129-
if (array.find("shape") == array.cend()) {
129+
auto shape_it = array.find("shape");
130+
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
130131
LOG(FATAL) << "Missing `shape' field for array interface";
131132
}
132-
if (array.find("data") == array.cend()) {
133+
auto data_it = array.find("data");
134+
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
133135
LOG(FATAL) << "Missing `data' field for array interface";
134136
}
135137
}
@@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
139141
static size_t ExtractMask(Object::Map const &column,
140142
common::Span<RBitField8::value_type> *p_out) {
141143
auto &s_mask = *p_out;
142-
if (column.find("mask") != column.cend()) {
143-
auto const &j_mask = get<Object const>(column.at("mask"));
144+
auto const &mask_it = column.find("mask");
145+
if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
146+
auto const &j_mask = get<Object const>(mask_it->second);
144147
Validate(j_mask);
145148

146149
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
@@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
173176
// assume 1 byte alignment.
174177
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
175178

176-
if (j_mask.find("strides") != j_mask.cend()) {
177-
auto strides = get<Array const>(column.at("strides"));
179+
auto strides_it = j_mask.find("strides");
180+
if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
181+
auto strides = get<Array const>(strides_it->second);
178182
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
179183
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
180184
}
@@ -401,7 +405,9 @@ class ArrayInterface {
401405
<< "XGBoost doesn't support internal broadcasting.";
402406
}
403407
} else {
404-
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
408+
auto mask_it = array.find("mask");
409+
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
410+
<< "Masked array is not yet supported.";
405411
}
406412

407413
auto stream_it = array.find("stream");

tests/cpp/data/test_array_interface.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ TEST(ArrayInterface, Error) {
3333
Json column { Object() };
3434
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
3535
column["shape"] = Array(j_shape);
36-
std::vector<Json> j_data {
37-
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
38-
Json(Boolean(false))};
36+
std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
37+
Json(Boolean(false))};
3938

4039
auto const& column_obj = get<Object>(column);
4140
std::string typestr{"<f4"};
@@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
4544
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
4645
column["version"] = 3;
4746
// missing data
47+
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
48+
dmlc::Error);
49+
// null data
50+
column["data"] = Null{};
4851
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
4952
dmlc::Error);
5053
column["data"] = j_data;
@@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
6366
Json(Boolean(false))};
6467
column["data"] = j_data;
6568
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
69+
// null data in mask
70+
column["mask"] = Object{};
71+
column["mask"]["data"] = Null{};
72+
common::Span<RBitField8::value_type> s_mask;
73+
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
6674
}
6775

6876
TEST(ArrayInterface, GetElement) {

0 commit comments

Comments
 (0)