@@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
101
101
template <typename PtrType>
102
102
static PtrType GetPtrFromArrayData (Object::Map const &obj) {
103
103
auto data_it = obj.find (" data" );
104
- if (data_it == obj.cend ()) {
104
+ if (data_it == obj.cend () || IsA<Null>(data_it-> second ) ) {
105
105
LOG (FATAL) << " Empty data passed in." ;
106
106
}
107
107
auto p_data = reinterpret_cast <PtrType>(
@@ -111,25 +111,27 @@ class ArrayInterfaceHandler {
111
111
112
112
static void Validate (Object::Map const &array) {
113
113
auto version_it = array.find (" version" );
114
- if (version_it == array.cend ()) {
114
+ if (version_it == array.cend () || IsA<Null>(version_it-> second ) ) {
115
115
LOG (FATAL) << " Missing `version' field for array interface" ;
116
116
}
117
117
if (get<Integer const >(version_it->second ) > 3 ) {
118
118
LOG (FATAL) << ArrayInterfaceErrors::Version ();
119
119
}
120
120
121
121
auto typestr_it = array.find (" typestr" );
122
- if (typestr_it == array.cend ()) {
122
+ if (typestr_it == array.cend () || IsA<Null>(typestr_it-> second ) ) {
123
123
LOG (FATAL) << " Missing `typestr' field for array interface" ;
124
124
}
125
125
126
126
auto typestr = get<String const >(typestr_it->second );
127
127
CHECK (typestr.size () == 3 || typestr.size () == 4 ) << ArrayInterfaceErrors::TypestrFormat ();
128
128
129
- if (array.find (" shape" ) == array.cend ()) {
129
+ auto shape_it = array.find (" shape" );
130
+ if (shape_it == array.cend () || IsA<Null>(shape_it->second )) {
130
131
LOG (FATAL) << " Missing `shape' field for array interface" ;
131
132
}
132
- if (array.find (" data" ) == array.cend ()) {
133
+ auto data_it = array.find (" data" );
134
+ if (data_it == array.cend () || IsA<Null>(data_it->second )) {
133
135
LOG (FATAL) << " Missing `data' field for array interface" ;
134
136
}
135
137
}
@@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
139
141
static size_t ExtractMask (Object::Map const &column,
140
142
common::Span<RBitField8::value_type> *p_out) {
141
143
auto &s_mask = *p_out;
142
- if (column.find (" mask" ) != column.cend ()) {
143
- auto const &j_mask = get<Object const >(column.at (" mask" ));
144
+ auto const &mask_it = column.find (" mask" );
145
+ if (mask_it != column.cend () && !IsA<Null>(mask_it->second )) {
146
+ auto const &j_mask = get<Object const >(mask_it->second );
144
147
Validate (j_mask);
145
148
146
149
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
@@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
173
176
// assume 1 byte alignment.
174
177
size_t const span_size = RBitField8::ComputeStorageSize (n_bits);
175
178
176
- if (j_mask.find (" strides" ) != j_mask.cend ()) {
177
- auto strides = get<Array const >(column.at (" strides" ));
179
+ auto strides_it = j_mask.find (" strides" );
180
+ if (strides_it != j_mask.cend () && !IsA<Null>(strides_it->second )) {
181
+ auto strides = get<Array const >(strides_it->second );
178
182
CHECK_EQ (strides.size (), 1 ) << ArrayInterfaceErrors::Dimension (1 );
179
183
CHECK_EQ (get<Integer>(strides.at (0 )), type_length) << ArrayInterfaceErrors::Contiguous ();
180
184
}
@@ -401,7 +405,9 @@ class ArrayInterface {
401
405
<< " XGBoost doesn't support internal broadcasting." ;
402
406
}
403
407
} else {
404
- CHECK (array.find (" mask" ) == array.cend ()) << " Masked array is not yet supported." ;
408
+ auto mask_it = array.find (" mask" );
409
+ CHECK (mask_it == array.cend () || IsA<Null>(mask_it->second ))
410
+ << " Masked array is not yet supported." ;
405
411
}
406
412
407
413
auto stream_it = array.find (" stream" );
0 commit comments