Skip to content

Commit 6094106

Browse files
trivialfishcho3
andauthored
[backport] Allow unaligned pointer if the array is empty. (dmlc#10418) (dmlc#10424)
Co-authored-by: Philip Hyunsu Cho <[email protected]>
1 parent f789e50 commit 6094106

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

src/data/array_interface.h

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,25 @@
66
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
77
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
88

9-
#include <algorithm>
10-
#include <cstddef> // for size_t
11-
#include <cstdint>
12-
#include <limits> // for numeric_limits
13-
#include <map>
14-
#include <string>
9+
#include <algorithm> // for all_of, transform, fill
10+
#include <cstddef> // for size_t
11+
#include <cstdint> // for int32_t, int64_t, ...
12+
#include <limits> // for numeric_limits
13+
#include <map> // for map
14+
#include <string> // for string
1515
#include <type_traits> // for alignment_of, remove_pointer_t, invoke_result_t
16-
#include <utility>
17-
#include <vector>
16+
#include <vector> // for vector
1817

19-
#include "../common/bitfield.h" // for RBitField8
20-
#include "../common/common.h"
18+
#include "../common/bitfield.h" // for RBitField8
2119
#include "../common/error_msg.h" // for NoF128
22-
#include "xgboost/base.h"
23-
#include "xgboost/data.h"
24-
#include "xgboost/json.h"
25-
#include "xgboost/linalg.h"
26-
#include "xgboost/logging.h"
27-
#include "xgboost/span.h"
20+
#include "xgboost/json.h" // for Json
21+
#include "xgboost/linalg.h" // for CalcStride, TensorView
22+
#include "xgboost/logging.h" // for CHECK
23+
#include "xgboost/span.h" // for Span
24+
#include "xgboost/string_view.h" // for StringView
2825

2926
#if defined(XGBOOST_USE_CUDA)
30-
#include "cuda_fp16.h"
27+
#include "cuda_fp16.h" // for __half
3128
#endif
3229

3330
namespace xgboost {
@@ -410,7 +407,7 @@ class ArrayInterface {
410407
auto typestr = get<String const>(array.at("typestr"));
411408
this->AssignType(StringView{typestr});
412409
ArrayInterfaceHandler::ExtractShape(array, shape);
413-
size_t itemsize = typestr[2] - '0';
410+
std::size_t itemsize = typestr[2] - '0';
414411
is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides);
415412
n = linalg::detail::CalcSize(shape);
416413

@@ -419,7 +416,9 @@ class ArrayInterface {
419416

420417
auto alignment = this->ElementAlignment();
421418
auto ptr = reinterpret_cast<uintptr_t>(this->data);
422-
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
419+
if (!std::all_of(this->shape, this->shape + D, [](auto v) { return v == 0; })) {
420+
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
421+
}
423422

424423
if (allow_mask) {
425424
common::Span<RBitField8::value_type> s_mask;

0 commit comments

Comments
 (0)