Skip to content

Commit d4e67f9

Browse files
committed
fix copy bugs, add constexpr test
1 parent 195a922 commit d4e67f9

2 files changed

Lines changed: 96 additions & 39 deletions

File tree

include/experimental/p3242_bits/copy.hpp

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,27 @@
2424
#include <cstring>
2525
#include <functional>
2626
#include <utility>
27+
#include <algorithm>
2728

2829
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
2930
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
3031
namespace detail {
3132

3233
template <class Extents, class F, class ArrayType>
33-
constexpr void apply_fun_over_extents(const Extents &ext, F &fun,
34+
constexpr void apply_fun_over_extents(const Extents &ext, F &&fun,
3435
ArrayType &indices,
3536
std::index_sequence<>) {
36-
std::apply(fun, indices);
37+
std::apply(std::forward<F>(fun), indices);
3738
}
3839

3940
template <class Extents, class F, class ArrayType, size_t R, size_t... Ranks>
40-
constexpr void apply_fun_over_extents(const Extents &ext, F &fun,
41+
constexpr void apply_fun_over_extents(const Extents &ext, F &&fun,
4142
ArrayType &indices,
4243
std::index_sequence<R, Ranks...>) {
4344
using index_type = typename Extents::index_type;
4445
for (index_type i = 0; i < ext.extent(R); ++i) {
4546
indices[R] = i;
46-
apply_fun_over_extents(ext, fun, indices, std::index_sequence<Ranks...>{});
47+
apply_fun_over_extents(ext, std::forward<F>(fun), indices, std::index_sequence<Ranks...>{});
4748
}
4849
}
4950

@@ -58,7 +59,7 @@ template <size_t N>
5859
using make_reverse_index_sequence = typename make_reverse_index_sequence_impl<
5960
N, std::make_index_sequence<N>>::type;
6061

61-
template <class SrcMDSpanType, class DstMDSpanType, typename Enabled = void>
62+
template <class SrcMDSpanType, class DstMDSpanType>
6263
struct mdspan_copy_impl {
6364
using extents_type = typename DstMDSpanType::extents_type;
6465

@@ -71,16 +72,16 @@ struct mdspan_copy_impl {
7172
constexpr auto rank = extents_type::rank();
7273
auto indices = std::array<typename extents_type::index_type, rank>{};
7374
apply_fun_over_extents(
74-
ext, [&src, &dst](auto... idxs) { dst(idxs...) = src(idxs...); },
75-
indices, make_reverse_index_sequence<rank>{});
75+
ext, [&src, &dst](auto... idxs) { dst[idxs...] = src[idxs...]; },
76+
indices, std::make_index_sequence<rank>{});
7677
}
7778
};
7879

79-
template <class ElementType, class SrcExtents, class DstExtents>
80+
template <class ElementType, class SrcExtents, class SrcLayout, class DstExtents, class DstLayout>
81+
requires (SrcLayout::template mapping<SrcExtents>::is_always_exhaustive() && DstLayout::template mapping<DstExtents>::is_always_exhaustive())
8082
struct mdspan_copy_impl<
81-
mdspan<ElementType, SrcExtents, layout_left, default_accessor<ElementType>>,
82-
mdspan<ElementType, DstExtents, layout_left, default_accessor<ElementType>>,
83-
void> {
83+
mdspan<ElementType, SrcExtents, SrcLayout>,
84+
mdspan<ElementType, DstExtents, DstLayout>> {
8485
using extents_type = DstExtents;
8586
using src_mdspan_type = mdspan<ElementType, SrcExtents, layout_left,
8687
default_accessor<ElementType>>;
@@ -90,33 +91,15 @@ struct mdspan_copy_impl<
9091
static constexpr void copy_over_extents(const extents_type &ext,
9192
const src_mdspan_type &src,
9293
const dst_mdspan_type &dst) {
93-
std::memcpy(dst.data_handle(), src.data_handle(), dst.mapping().required_span_size() * sizeof(ElementType));
94-
}
95-
};
96-
97-
template <class ElementType, class SrcExtents, class DstExtents>
98-
struct mdspan_copy_impl<
99-
mdspan<ElementType, SrcExtents, layout_right, default_accessor<ElementType>>,
100-
mdspan<ElementType, DstExtents, layout_right, default_accessor<ElementType>>,
101-
void> {
102-
using extents_type = DstExtents;
103-
using src_mdspan_type = mdspan<ElementType, SrcExtents, layout_left,
104-
default_accessor<ElementType>>;
105-
using dst_mdspan_type = mdspan<ElementType, DstExtents, layout_left,
106-
default_accessor<ElementType>>;
107-
108-
static constexpr void copy_over_extents(const extents_type &ext,
109-
const src_mdspan_type &src,
110-
const dst_mdspan_type &dst) {
111-
std::memcpy(dst.data_handle(), src.data_handle(), dst.mapping().required_span_size() * sizeof(ElementType));
94+
std::copy(src.data_handle(), src.data_handle() + src.mapping().required_span_size(), dst.data_handle());
11295
}
11396
};
11497
} // namespace detail
11598

11699
template <class SrcElementType, class SrcExtents, class SrcLayoutPolicy,
117100
class SrcAccessorPolicy, class DstElementType, class DstExtents,
118101
class DstLayoutPolicy, class DstAccessorPolicy>
119-
void copy(
102+
constexpr void copy(
120103
mdspan<SrcElementType, SrcExtents, SrcLayoutPolicy, SrcAccessorPolicy> src,
121104
mdspan<DstElementType, DstExtents, DstLayoutPolicy, DstAccessorPolicy>
122105
dst) {

tests/test_mdspan_copy_and_fill.cpp

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,93 @@
1616

1717
#include <mdspan/mdarray.hpp>
1818
#include <mdspan/mdspan.hpp>
19-
19+
#include <array>
2020
#include <gtest/gtest.h>
2121

22+
template <class T, class Layout, class Extents>
23+
constexpr auto make_mdarray(const Extents &exts) {
24+
Kokkos::Experimental::mdarray<T, Extents, Layout> mds{exts};
25+
for (size_t i = 0; i < mds.mapping().required_span_size(); ++i)
26+
mds.data()[i] = static_cast<T>(i);
27+
28+
return mds;
29+
}
30+
31+
template <class T, class Extents, std::size_t... Indices>
32+
constexpr auto make_constexpr_array_impl(std::index_sequence<Indices...>) {
33+
constexpr auto sz = (Extents::static_extent(Indices) * ...);
34+
std::array<T, sz> ret;
35+
for (size_t i = 0; i < sz; ++i)
36+
ret[i] = static_cast<T>(i);
37+
38+
return ret;
39+
}
40+
41+
template <class Extents, std::size_t... Indices>
42+
constexpr auto array_size_impl(std::index_sequence<Indices...>) {
43+
return (Extents::static_extent(Indices) * ...);
44+
}
45+
46+
47+
template <class Extents>
48+
constexpr size_t array_size() {
49+
return array_size_impl<Extents>();
50+
}
51+
52+
53+
template <class T, class Extents>
54+
constexpr auto make_constexpr_array() {
55+
return make_constexpr_array_impl<T, Extents>(std::make_index_sequence<Extents::rank()>{});
56+
}
57+
58+
template <class T, class Extents, class Layout, class SrcExtents,
59+
class SrcLayout>
60+
constexpr auto make_mdarray_copy(const Kokkos::mdspan<T, SrcExtents, SrcLayout> &src) {
61+
Kokkos::Experimental::mdarray<T, Extents, Layout> dst{};
62+
Kokkos::Experimental::copy(src, dst.to_mdspan());
63+
64+
return dst;
65+
}
66+
67+
template <class T, class SrcLayout, class DstLayout, class SrcExtents,
68+
class DstExtents>
69+
constexpr bool test_mdspan_copy_check(const SrcExtents &src_exts,
70+
const DstExtents &dst_exts) {
71+
Kokkos::Experimental::mdarray<T, SrcExtents, SrcLayout> src1 =
72+
make_mdarray<T, SrcLayout>(src_exts);
73+
Kokkos::Experimental::mdarray<T, DstExtents, DstLayout> dst1{dst_exts};
74+
75+
if (dst1.container() == src1.container()) return false;
76+
Kokkos::Experimental::copy(src1.to_mdspan(), dst1.to_mdspan());
77+
return dst1.container() == src1.container();
78+
}
79+
2280
template <class T, class SrcLayout, class DstLayout, class SrcExtents,
2381
class DstExtents>
2482
void test_mdspan_copy_impl(const SrcExtents &src_exts,
2583
const DstExtents &dst_exts) {
26-
Kokkos::Experimental::mdarray<T, SrcExtents, SrcLayout> src1{src_exts};
27-
Kokkos::Experimental::mdarray<T, DstExtents, DstLayout> dst1{dst_exts};
28-
auto &src1c = src1.container();
29-
for (size_t i = 0; i < src1c.size(); ++i)
30-
src1c[i] = static_cast<T>(i * i);
84+
ASSERT_TRUE(
85+
(test_mdspan_copy_check<T, SrcLayout, DstLayout>(src_exts, dst_exts)));
86+
}
87+
88+
template <class T, class SrcExtents, class SrcLayout, class DstExtents, class DstLayout>
89+
constexpr bool test_mdspan_copy_constexpr_impl() {
90+
auto arr = make_constexpr_array<T, SrcExtents>();
91+
Kokkos::Experimental::mdarray<T, SrcExtents, SrcLayout, decltype(arr)> src1(SrcExtents{}, arr);
92+
Kokkos::Experimental::mdarray<T, DstExtents, DstLayout> dst1{};
3193

32-
ASSERT_NE(dst1.container(), src1.container());
3394
Kokkos::Experimental::copy(src1.to_mdspan(), dst1.to_mdspan());
34-
ASSERT_EQ(dst1.container(), src1.container());
95+
96+
for ( std::size_t i = 0; i < arr.size(); ++i )
97+
if (dst1.container()[i] != src1.container()[i])
98+
return false;
99+
100+
return true;
101+
}
102+
103+
template <class T, class SrcExtents, class SrcLayout, class DstExtents, class DstLayout>
104+
void test_mdspan_copy_constexpr() {
105+
static_assert(test_mdspan_copy_constexpr_impl<T, SrcExtents, SrcLayout, DstExtents, DstLayout>());
35106
}
36107

37108
TEST(TestMdspanCopyAndFill, test_mdspan_copy) {
@@ -62,4 +133,7 @@ TEST(TestMdspanCopyAndFill, test_mdspan_copy) {
62133
Kokkos::dextents<size_t, 2>{5, 3}, Kokkos::dextents<size_t, 2>{5, 3});
63134
test_mdspan_copy_impl<float, Kokkos::layout_left, Kokkos::layout_left>(
64135
Kokkos::dextents<size_t, 2>{5, 3}, Kokkos::dextents<size_t, 2>{5, 3});
136+
137+
test_mdspan_copy_constexpr<int, Kokkos::extents<size_t, 5, 3>, Kokkos::layout_left,
138+
Kokkos::extents<size_t, 5, 3>, Kokkos::layout_left>();
65139
}

0 commit comments

Comments
 (0)