Skip to content

Commit 7747a8c

Browse files
author
Fikret Ardal
committed
merge with unstable branch
2 parents 084969b + ec439ef commit 7747a8c

21 files changed

+272
-67
lines changed

benchmarks/constr_heap_vs_sso.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct toy_mat1_t {
4242

4343
struct toy_mat2_t {
4444
sso<1000>::handle<double> storage;
45-
toy_mat2_t(){}; // Custom constructor -> Not an aggregate type
45+
toy_mat2_t() {}; // Custom constructor -> Not an aggregate type
4646
};
4747

4848
BENCH_EXPR(constr_toy_mat1, toy_mat1_t{});

c++/nda/_impl_basic_array_view_common.hpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -525,19 +525,39 @@ void assign_from_ndarray(RHS const &rhs) { // FIXME noexcept {
525525
template <typename Scalar>
526526
void fill_with_scalar(Scalar const &scalar) noexcept {
527527
// we make a special implementation if the array is strided in 1d or contiguous
528-
if constexpr (has_layout_strided_1d<self_t>) {
529-
const long L = indexmap().size();
530-
auto *__restrict const p = data(); // no alias possible here!
531-
if constexpr (has_contiguous_layout<self_t>) {
532-
for (long i = 0; i < L; ++i) p[i] = scalar;
528+
if constexpr (mem::on_host<self_t>) {
529+
if constexpr (has_layout_strided_1d<self_t>) {
530+
const long L = size();
531+
auto *__restrict const p = data(); // no alias possible here!
532+
if constexpr (has_contiguous_layout<self_t>) {
533+
for (long i = 0; i < L; ++i) p[i] = scalar;
534+
} else {
535+
const long stri = indexmap().min_stride();
536+
const long Lstri = L * stri;
537+
for (long i = 0; i != Lstri; i += stri) p[i] = scalar;
538+
}
533539
} else {
534-
const long stri = indexmap().min_stride();
535-
const long Lstri = L * stri;
536-
for (long i = 0; i != Lstri; i += stri) p[i] = scalar;
540+
for (auto &x : *this) x = scalar;
541+
}
542+
} else if constexpr (mem::on_device<self_t> or mem::on_unified<self_t>) { // on device
543+
if constexpr (has_layout_strided_1d<self_t>) { // possibly contiguous
544+
if constexpr (has_contiguous_layout<self_t>) {
545+
mem::fill_n<mem::get_addr_space<self_t>>(data(), size(), value_type(scalar));
546+
} else {
547+
const long stri = indexmap().min_stride();
548+
mem::fill2D_n<mem::get_addr_space<self_t>>(data(), stri, 1, size(), value_type(scalar));
549+
}
550+
} else {
551+
// check for 2D layout
552+
auto bl_layout = get_block_layout(*this);
553+
if (bl_layout) {
554+
auto [n_bl, bl_size, bl_str] = *bl_layout;
555+
mem::fill2D_n<mem::get_addr_space<self_t>>(data(), bl_str, bl_size, n_bl, value_type(scalar));
556+
} else {
557+
// MAM: implement recursive call to fill_with_scalar on (i,nda::ellipsis{})
558+
NDA_RUNTIME_ERROR << "fill_with_scalar: Not implemented yet for general layout. ";
559+
}
537560
}
538-
} else {
539-
// no compile-time memory layout guarantees
540-
for (auto &x : *this) x = scalar;
541561
}
542562
}
543563

@@ -556,7 +576,6 @@ void assign_from_scalar(Scalar const &scalar) noexcept {
556576
fill_with_scalar(0);
557577
else
558578
fill_with_scalar(Scalar{0 * scalar}); // FIXME : improve this
559-
const long imax = std::min(extent(0), extent(1));
560-
for (long i = 0; i < imax; ++i) operator()(i, i) = scalar;
579+
diagonal(*this).fill_with_scalar(scalar);
561580
}
562581
}

c++/nda/basic_array_view.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
#include "./declarations.hpp"
1717
#include "./exceptions.hpp"
1818
#include "./iterators.hpp"
19+
#include "layout/slice_static.hpp"
1920
#include "./layout/for_each.hpp"
2021
#include "./layout/idx_map.hpp"
2122
#include "./layout/permutation.hpp"
2223
#include "./layout/range.hpp"
2324
#include "./macros.hpp"
2425
#include "./mem/address_space.hpp"
2526
#include "./mem/memcpy.hpp"
27+
#include "./mem/memset.hpp"
28+
#include "./mem/fill.hpp"
2629
#include "./mem/policies.hpp"
2730
#include "./traits.hpp"
2831
#include "./config.hpp"

c++/nda/basic_functions.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ namespace nda {
605605
// slicing helper function
606606
auto slice_Axis = [](Array auto &a, range r) {
607607
auto all_or_range = std::make_tuple(range::all, r);
608-
return [&]<auto... Is>(std::index_sequence<Is...>) { return a(std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
608+
return [&]<auto... Is>(std::index_sequence<Is...>) { return a(std::get < Is == Axis > (all_or_range)...); }(std::make_index_sequence<rank>{});
609609
};
610610

611611
// initialize concatenated array

c++/nda/blas/gemm.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,15 @@ namespace nda::blas {
5050
EXPECTS(a.extent(1) == b.extent(0));
5151
EXPECTS(a.extent(0) == c.extent(0));
5252
EXPECTS(b.extent(1) == c.extent(1));
53+
54+
if (beta == 0.0) {
55+
c = 0 * alpha;
56+
} else {
57+
c *= beta;
58+
}
59+
5360
for (int i = 0; i < a.extent(0); ++i) {
5461
for (int j = 0; j < b.extent(1); ++j) {
55-
c(i, j) = beta * c(i, j);
5662
for (int k = 0; k < a.extent(1); ++k) c(i, j) += alpha * a(i, k) * b(k, j);
5763
}
5864
}

c++/nda/blas/gemv.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ namespace nda::blas {
4646
void gemv_generic(get_value_t<A> alpha, A const &a, X const &x, get_value_t<A> beta, Y &&y) { // NOLINT (temporary views are allowed here)
4747
EXPECTS(a.extent(1) == x.extent(0));
4848
EXPECTS(a.extent(0) == y.extent(0));
49+
50+
if (beta == 0.0) {
51+
y = 0 * alpha;
52+
} else {
53+
y *= beta;
54+
}
55+
4956
for (int i = 0; i < a.extent(0); ++i) {
50-
y(i) = beta * y(i);
5157
for (int k = 0; k < a.extent(1); ++k) y(i) += alpha * a(i, k) * x(k);
5258
}
5359
}

c++/nda/clef/literals.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ namespace nda::clef::literals {
2020
*/
2121

2222
// Define literal placeholders starting from the end of the allowed index spectrum.
23-
#define PH(I) \
24-
(placeholder<63 - (I)> {})
23+
#define PH(I) (placeholder<63 - (I)>{})
2524

2625
/// Generic placeholder #1.
2726
constexpr auto i_ = PH(0);
@@ -58,6 +57,6 @@ namespace nda::clef::literals {
5857

5958
#undef PH
6059

61-
/** @} */
60+
/** @} */
6261

6362
} // namespace nda::clef::literals

c++/nda/clef/make_lazy.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,8 @@ namespace nda::clef {
122122
} \
123123
\
124124
template <typename... Args> \
125-
auto operator()(Args &&...args) && \
126-
requires(nda::clef::is_any_lazy<Args...>) \
127-
{ \
128-
return make_expr_call(std::move(*this), std::forward<Args>(args)...); \
129-
}
125+
auto operator()(Args &&...args) \
126+
&& requires(nda::clef::is_any_lazy<Args...>) { return make_expr_call(std::move(*this), std::forward<Args>(args)...); }
130127

131128
/** @} */
132129

c++/nda/gtest_tools.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ ::testing::AssertionResult array_almost_zero(X const &x) {
126126
nda::array<nda::get_value_t<X>, nda::get_rank<X>> x_reg = x;
127127

128128
constexpr double eps = 1.e-10;
129-
const auto max = max_element(abs(x_reg));
129+
const auto max = max_element(abs(x_reg));
130130
if (x_reg.size() == 0 || max < eps)
131131
return ::testing::AssertionSuccess();
132132
else

c++/nda/h5.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ namespace nda {
5151
// Given an array/view, prepare and return the corresponding h5::array_view to be written/read into.
5252
template <MemoryArray A>
5353
auto prepare_h5_array_view(const A &a) {
54-
auto [parent_shape, h5_strides] = h5::array_interface::get_parent_shape_and_h5_strides(a.indexmap().strides().data(), A::rank, a.shape().data());
54+
auto [parent_shape, h5_strides] =
55+
h5::array_interface::get_parent_shape_and_h5_strides(a.indexmap().strides().data(), A::rank, a.shape().data());
5556
auto v = h5::array_interface::array_view{h5::hdf5_type<get_value_t<A>>(), (void *)a.data(), A::rank, is_complex_v<typename A::value_type>};
5657
for (int u = 0; u < A::rank; ++u) {
5758
v.slab.count[u] = a.shape()[u];

0 commit comments

Comments
 (0)