6
6
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
7
7
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
8
8
9
- #include < algorithm>
10
- #include < cstddef> // for size_t
11
- #include < cstdint>
12
- #include < limits> // for numeric_limits
13
- #include < map>
14
- #include < string>
9
+ #include < algorithm> // for all_of, transform, fill
10
+ #include < cstddef> // for size_t
11
+ #include < cstdint> // for int32_t, int64_t, ...
12
+ #include < limits> // for numeric_limits
13
+ #include < map> // for map
14
+ #include < string> // for string
15
15
#include < type_traits> // for alignment_of, remove_pointer_t, invoke_result_t
16
- #include < utility>
17
- #include < vector>
16
+ #include < vector> // for vector
18
17
19
- #include " ../common/bitfield.h" // for RBitField8
20
- #include " ../common/common.h"
18
+ #include " ../common/bitfield.h" // for RBitField8
21
19
#include " ../common/error_msg.h" // for NoF128
22
- #include " xgboost/base.h"
23
- #include " xgboost/data.h"
24
- #include " xgboost/json.h"
25
- #include " xgboost/linalg.h"
26
- #include " xgboost/logging.h"
27
- #include " xgboost/span.h"
20
+ #include " xgboost/json.h" // for Json
21
+ #include " xgboost/linalg.h" // for CalcStride, TensorView
22
+ #include " xgboost/logging.h" // for CHECK
23
+ #include " xgboost/span.h" // for Span
24
+ #include " xgboost/string_view.h" // for StringView
28
25
29
26
#if defined(XGBOOST_USE_CUDA)
30
- #include " cuda_fp16.h"
27
+ #include " cuda_fp16.h" // for __half
31
28
#endif
32
29
33
30
namespace xgboost {
@@ -410,7 +407,7 @@ class ArrayInterface {
410
407
auto typestr = get<String const >(array.at (" typestr" ));
411
408
this ->AssignType (StringView{typestr});
412
409
ArrayInterfaceHandler::ExtractShape (array, shape);
413
- size_t itemsize = typestr[2 ] - ' 0' ;
410
+ std:: size_t itemsize = typestr[2 ] - ' 0' ;
414
411
is_contiguous = ArrayInterfaceHandler::ExtractStride (array, itemsize, shape, strides);
415
412
n = linalg::detail::CalcSize (shape);
416
413
@@ -419,7 +416,9 @@ class ArrayInterface {
419
416
420
417
auto alignment = this ->ElementAlignment ();
421
418
auto ptr = reinterpret_cast <uintptr_t >(this ->data );
422
- CHECK_EQ (ptr % alignment, 0 ) << " Input pointer misalignment." ;
419
+ if (!std::all_of (this ->shape , this ->shape + D, [](auto v) { return v == 0 ; })) {
420
+ CHECK_EQ (ptr % alignment, 0 ) << " Input pointer misalignment." ;
421
+ }
423
422
424
423
if (allow_mask) {
425
424
common::Span<RBitField8::value_type> s_mask;
0 commit comments