Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 5f3ded6

Browse files
authored
Implement XTensor support in core. (#976)
* Implement XTensor support in core. Adds support for `xt::xtensor`, `xt::xarray` and `xt::xview`, both row and column major. This works by wrapping the internal row-major with `xt::adapt`. Therefore, the `T` in `xt::xtensor<T, ...>` must be scalar (trivial). * In tests/examples: Define `NOMINMAX` for Windows.
1 parent 7498b80 commit 5f3ded6

File tree

6 files changed

+480
-6
lines changed

6 files changed

+480
-6
lines changed

include/highfive/xtensor.hpp

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
#pragma once
2+
3+
#include "bits/H5Inspector_decl.hpp"
4+
#include "H5Exception.hpp"
5+
6+
#include <xtensor/xtensor.hpp>
7+
#include <xtensor/xarray.hpp>
8+
#include <xtensor/xadapt.hpp>
9+
10+
namespace HighFive {
11+
namespace details {
12+
13+
template <class XTensor>
14+
struct xtensor_get_rank;
15+
16+
template <typename T, size_t N, xt::layout_type L>
17+
struct xtensor_get_rank<xt::xtensor<T, N, L>> {
18+
static constexpr size_t value = N;
19+
};
20+
21+
template <class EC, size_t N, xt::layout_type L, class Tag>
22+
struct xtensor_get_rank<xt::xtensor_adaptor<EC, N, L, Tag>> {
23+
static constexpr size_t value = N;
24+
};
25+
26+
template <class Derived, class XTensorType, xt::layout_type L>
27+
struct xtensor_inspector_base {
28+
using type = XTensorType;
29+
using value_type = typename type::value_type;
30+
using base_type = typename inspector<value_type>::base_type;
31+
using hdf5_type = base_type;
32+
33+
static_assert(std::is_same<value_type, base_type>::value,
34+
"HighFive's XTensor support only works for scalar elements.");
35+
36+
static constexpr bool IsConstExprRowMajor = L == xt::layout_type::row_major;
37+
static constexpr bool is_trivially_copyable = IsConstExprRowMajor &&
38+
std::is_trivially_copyable<value_type>::value &&
39+
inspector<value_type>::is_trivially_copyable;
40+
41+
static constexpr bool is_trivially_nestable = false;
42+
43+
static size_t getRank(const type& val) {
44+
// Non-scalar elements are not supported.
45+
return val.shape().size();
46+
}
47+
48+
static const value_type& getAnyElement(const type& val) {
49+
return val.unchecked(0);
50+
}
51+
52+
static value_type& getAnyElement(type& val) {
53+
return val.unchecked(0);
54+
}
55+
56+
static std::vector<size_t> getDimensions(const type& val) {
57+
auto shape = val.shape();
58+
return {shape.begin(), shape.end()};
59+
}
60+
61+
static void prepare(type& val, const std::vector<size_t>& dims) {
62+
val.resize(Derived::shapeFromDims(dims));
63+
}
64+
65+
static hdf5_type* data(type& val) {
66+
if (!is_trivially_copyable) {
67+
throw DataSetException("Invalid used of `inspector<XTensor>::data`.");
68+
}
69+
70+
if (val.size() == 0) {
71+
return nullptr;
72+
}
73+
74+
return inspector<value_type>::data(getAnyElement(val));
75+
}
76+
77+
static const hdf5_type* data(const type& val) {
78+
if (!is_trivially_copyable) {
79+
throw DataSetException("Invalid used of `inspector<XTensor>::data`.");
80+
}
81+
82+
if (val.size() == 0) {
83+
return nullptr;
84+
}
85+
86+
return inspector<value_type>::data(getAnyElement(val));
87+
}
88+
89+
static void serialize(const type& val, const std::vector<size_t>& dims, hdf5_type* m) {
90+
// since we only support scalar types we know all dims belong to us.
91+
size_t size = compute_total_size(dims);
92+
xt::adapt(m, size, xt::no_ownership(), dims) = val;
93+
}
94+
95+
static void unserialize(const hdf5_type* vec_align,
96+
const std::vector<size_t>& dims,
97+
type& val) {
98+
// since we only support scalar types we know all dims belong to us.
99+
size_t size = compute_total_size(dims);
100+
val = xt::adapt(vec_align, size, xt::no_ownership(), dims);
101+
}
102+
};
103+
104+
template <class XTensorType, xt::layout_type L>
105+
struct xtensor_inspector
106+
: public xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L> {
107+
private:
108+
using super = xtensor_inspector_base<xtensor_inspector<XTensorType, L>, XTensorType, L>;
109+
110+
public:
111+
using type = typename super::type;
112+
using value_type = typename super::value_type;
113+
using base_type = typename super::base_type;
114+
using hdf5_type = typename super::hdf5_type;
115+
116+
static constexpr size_t ndim = xtensor_get_rank<XTensorType>::value;
117+
static constexpr size_t min_ndim = ndim + inspector<value_type>::min_ndim;
118+
static constexpr size_t max_ndim = ndim + inspector<value_type>::max_ndim;
119+
120+
static std::array<size_t, ndim> shapeFromDims(const std::vector<size_t>& dims) {
121+
std::array<size_t, ndim> shape;
122+
std::copy(dims.cbegin(), dims.cend(), shape.begin());
123+
return shape;
124+
}
125+
};
126+
127+
template <class XArrayType, xt::layout_type L>
128+
struct xarray_inspector
129+
: public xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L> {
130+
private:
131+
using super = xtensor_inspector_base<xarray_inspector<XArrayType, L>, XArrayType, L>;
132+
133+
public:
134+
using type = typename super::type;
135+
using value_type = typename super::value_type;
136+
using base_type = typename super::base_type;
137+
using hdf5_type = typename super::hdf5_type;
138+
139+
static constexpr size_t min_ndim = 0 + inspector<value_type>::min_ndim;
140+
static constexpr size_t max_ndim = 1024 + inspector<value_type>::max_ndim;
141+
142+
static const std::vector<size_t>& shapeFromDims(const std::vector<size_t>& dims) {
143+
return dims;
144+
}
145+
};
146+
147+
template <typename T, size_t N, xt::layout_type L>
148+
struct inspector<xt::xtensor<T, N, L>>: public xtensor_inspector<xt::xtensor<T, N, L>, L> {
149+
private:
150+
using super = xtensor_inspector<xt::xtensor<T, N, L>, L>;
151+
152+
public:
153+
using type = typename super::type;
154+
using value_type = typename super::value_type;
155+
using base_type = typename super::base_type;
156+
using hdf5_type = typename super::hdf5_type;
157+
};
158+
159+
template <typename T, xt::layout_type L>
160+
struct inspector<xt::xarray<T, L>>: public xarray_inspector<xt::xarray<T, L>, L> {
161+
private:
162+
using super = xarray_inspector<xt::xarray<T, L>, L>;
163+
164+
public:
165+
using type = typename super::type;
166+
using value_type = typename super::value_type;
167+
using base_type = typename super::base_type;
168+
using hdf5_type = typename super::hdf5_type;
169+
};
170+
171+
template <typename CT, class... S>
172+
struct inspector<xt::xview<CT, S...>>
173+
: public xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any> {
174+
private:
175+
using super = xarray_inspector<xt::xview<CT, S...>, xt::layout_type::any>;
176+
177+
public:
178+
using type = typename super::type;
179+
using value_type = typename super::value_type;
180+
using base_type = typename super::base_type;
181+
using hdf5_type = typename super::hdf5_type;
182+
};
183+
184+
185+
template <class EC, xt::layout_type L, class SC, class Tag>
186+
struct inspector<xt::xarray_adaptor<EC, L, SC, Tag>>
187+
: public xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any> {
188+
private:
189+
using super = xarray_inspector<xt::xarray_adaptor<EC, L, SC, Tag>, xt::layout_type::any>;
190+
191+
public:
192+
using type = typename super::type;
193+
using value_type = typename super::value_type;
194+
using base_type = typename super::base_type;
195+
using hdf5_type = typename super::hdf5_type;
196+
};
197+
198+
template <class EC, size_t N, xt::layout_type L, class Tag>
199+
struct inspector<xt::xtensor_adaptor<EC, N, L, Tag>>
200+
: public xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any> {
201+
private:
202+
using super = xtensor_inspector<xt::xtensor_adaptor<EC, N, L, Tag>, xt::layout_type::any>;
203+
204+
public:
205+
using type = typename super::type;
206+
using value_type = typename super::value_type;
207+
using base_type = typename super::base_type;
208+
using hdf5_type = typename super::hdf5_type;
209+
};
210+
211+
} // namespace details
212+
} // namespace HighFive

tests/unit/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ if(MSVC)
66
endif()
77

88
## Base tests
9-
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string)
9+
foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string test_xtensor)
1010
add_executable(${test_name} "${test_name}.cpp")
1111
target_link_libraries(${test_name} HighFive HighFiveWarnings HighFiveFlags Catch2::Catch2WithMain)
1212
target_link_libraries(${test_name} HighFiveOptionalDependencies)
@@ -47,7 +47,7 @@ endif()
4747
# test succeeds if it compiles.
4848
file(GLOB public_headers LIST_DIRECTORIES false RELATIVE ${PROJECT_SOURCE_DIR}/include CONFIGURE_DEPENDS ${PROJECT_SOURCE_DIR}/include/highfive/*.hpp)
4949
foreach(PUBLIC_HEADER ${public_headers})
50-
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
50+
if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN)
5151
continue()
5252
endif()
5353

@@ -67,6 +67,10 @@ foreach(PUBLIC_HEADER ${public_headers})
6767
continue()
6868
endif()
6969

70+
if(PUBLIC_HEADER STREQUAL "highfive/xtensor.hpp" AND NOT HIGHFIVE_TEST_XTENSOR)
71+
continue()
72+
endif()
73+
7074
get_filename_component(CLASS_NAME ${PUBLIC_HEADER} NAME_WE)
7175
configure_file(tests_import_public_headers.cpp "tests_${CLASS_NAME}.cpp" @ONLY)
7276
add_executable("tests_include_${CLASS_NAME}" "${CMAKE_CURRENT_BINARY_DIR}/tests_${CLASS_NAME}.cpp")

tests/unit/data_generator.hpp

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121
#include <highfive/span.hpp>
2222
#endif
2323

24+
#ifdef HIGHFIVE_TEST_XTENSOR
25+
#include <highfive/xtensor.hpp>
26+
#endif
27+
2428

2529
namespace HighFive {
2630
namespace testing {
2731

28-
std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
32+
template <class Dims>
33+
std::vector<size_t> lstrip(const Dims& indices, size_t n) {
2934
std::vector<size_t> subindices(indices.size() - n);
3035
for (size_t i = 0; i < subindices.size(); ++i) {
3136
subindices[i] = indices[i + n];
@@ -34,7 +39,8 @@ std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
3439
return subindices;
3540
}
3641

37-
size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
42+
template <class Dims>
43+
size_t ravel(std::vector<size_t>& indices, const Dims& dims) {
3844
size_t rank = dims.size();
3945
size_t linear_index = 0;
4046
size_t ld = 1;
@@ -47,7 +53,8 @@ size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
4753
return linear_index;
4854
}
4955

50-
std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
56+
template <class Dims>
57+
std::vector<size_t> unravel(size_t flat_index, const Dims& dims) {
5158
size_t rank = dims.size();
5259
size_t ld = 1;
5360
std::vector<size_t> indices(rank);
@@ -60,7 +67,8 @@ std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
6067
return indices;
6168
}
6269

63-
static size_t flat_size(const std::vector<size_t>& dims) {
70+
template <class Dims>
71+
static size_t flat_size(const Dims& dims) {
6472
size_t n = 1;
6573
for (auto d: dims) {
6674
n *= d;
@@ -388,6 +396,7 @@ struct ContainerTraits<boost::numeric::ublas::matrix<T>> {
388396

389397
#endif
390398

399+
// -- Eigen -------------------------------------------------------------------
391400
#if HIGHFIVE_TEST_EIGEN
392401

393402
template <typename EigenType>
@@ -525,6 +534,88 @@ struct ContainerTraits<Eigen::Map<PlainObjectType, MapOptions>>
525534
};
526535

527536

537+
#endif
538+
539+
// -- XTensor -----------------------------------------------------------------
540+
541+
#if HIGHFIVE_TEST_XTENSOR
542+
template <typename XTensorType, size_t Rank>
543+
struct XTensorContainerTraits {
544+
using container_type = XTensorType;
545+
using value_type = typename container_type::value_type;
546+
using base_type = typename ContainerTraits<value_type>::base_type;
547+
548+
static constexpr size_t rank = Rank;
549+
static constexpr bool is_view = ContainerTraits<value_type>::is_view;
550+
551+
static void set(container_type& array,
552+
const std::vector<size_t>& indices,
553+
const base_type& value) {
554+
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
555+
return ContainerTraits<value_type>::set(array[local_indices], lstrip(indices, rank), value);
556+
}
557+
558+
static base_type get(const container_type& array, const std::vector<size_t>& indices) {
559+
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
560+
return ContainerTraits<value_type>::get(array[local_indices], lstrip(indices, rank));
561+
}
562+
563+
static void assign(container_type& dst, const container_type& src) {
564+
dst = src;
565+
}
566+
567+
static container_type allocate(const std::vector<size_t>& dims) {
568+
const auto& local_dims = details::inspector<XTensorType>::shapeFromDims(dims);
569+
auto array = container_type(local_dims);
570+
571+
size_t n_elements = flat_size(local_dims);
572+
for (size_t i = 0; i < n_elements; ++i) {
573+
auto element = ContainerTraits<value_type>::allocate(lstrip(dims, rank));
574+
set(array, unravel(i, local_dims), element);
575+
}
576+
577+
return array;
578+
}
579+
580+
static void deallocate(container_type& array, const std::vector<size_t>& dims) {
581+
auto local_dims = std::vector<size_t>(dims.begin(), dims.begin() + rank);
582+
size_t n_elements = flat_size(local_dims);
583+
for (size_t i_flat = 0; i_flat < n_elements; ++i_flat) {
584+
auto indices = unravel(i_flat, local_dims);
585+
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
586+
ContainerTraits<value_type>::deallocate(array[local_indices], lstrip(dims, rank));
587+
}
588+
}
589+
590+
static void sanitize_dims(std::vector<size_t>& dims, size_t axis) {
591+
ContainerTraits<value_type>::sanitize_dims(dims, axis + rank);
592+
}
593+
};
594+
595+
template <class T, size_t rank, xt::layout_type layout>
596+
struct ContainerTraits<xt::xtensor<T, rank, layout>>
597+
: public XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank> {
598+
private:
599+
using super = XTensorContainerTraits<xt::xtensor<T, rank, layout>, rank>;
600+
601+
public:
602+
using container_type = typename super::container_type;
603+
using value_type = typename super::value_type;
604+
using base_type = typename super::base_type;
605+
};
606+
607+
template <class T, xt::layout_type layout>
608+
struct ContainerTraits<xt::xarray<T, layout>>
609+
: public XTensorContainerTraits<xt::xarray<T, layout>, 2> {
610+
private:
611+
using super = XTensorContainerTraits<xt::xarray<T, layout>, 2>;
612+
613+
public:
614+
using container_type = typename super::container_type;
615+
using value_type = typename super::value_type;
616+
using base_type = typename super::base_type;
617+
};
618+
528619
#endif
529620

530621
template <class T, class C>

0 commit comments

Comments
 (0)