@@ -525,19 +525,39 @@ void assign_from_ndarray(RHS const &rhs) { // FIXME noexcept {
525525template <typename Scalar>
526526void fill_with_scalar (Scalar const &scalar) noexcept {
527527 // we make a special implementation if the array is strided in 1d or contiguous
528- if constexpr (has_layout_strided_1d<self_t >) {
529- const long L = indexmap ().size ();
530- auto *__restrict const p = data (); // no alias possible here!
531- if constexpr (has_contiguous_layout<self_t >) {
532- for (long i = 0 ; i < L; ++i) p[i] = scalar;
528+ if constexpr (mem::on_host<self_t >) {
529+ if constexpr (has_layout_strided_1d<self_t >) {
530+ const long L = size ();
531+ auto *__restrict const p = data (); // no alias possible here!
532+ if constexpr (has_contiguous_layout<self_t >) {
533+ for (long i = 0 ; i < L; ++i) p[i] = scalar;
534+ } else {
535+ const long stri = indexmap ().min_stride ();
536+ const long Lstri = L * stri;
537+ for (long i = 0 ; i != Lstri; i += stri) p[i] = scalar;
538+ }
533539 } else {
534- const long stri = indexmap ().min_stride ();
535- const long Lstri = L * stri;
536- for (long i = 0 ; i != Lstri; i += stri) p[i] = scalar;
540+ for (auto &x : *this ) x = scalar;
541+ }
542+ } else if constexpr (mem::on_device<self_t > or mem::on_unified<self_t >) { // on device
543+ if constexpr (has_layout_strided_1d<self_t >) { // possibly contiguous
544+ if constexpr (has_contiguous_layout<self_t >) {
545+ mem::fill_n<mem::get_addr_space<self_t >>(data (), size (), value_type (scalar));
546+ } else {
547+ const long stri = indexmap ().min_stride ();
548+ mem::fill2D_n<mem::get_addr_space<self_t >>(data (), stri, 1 , size (), value_type (scalar));
549+ }
550+ } else {
551+ // check for 2D layout
552+ auto bl_layout = get_block_layout (*this );
553+ if (bl_layout) {
554+ auto [n_bl, bl_size, bl_str] = *bl_layout;
555+ mem::fill2D_n<mem::get_addr_space<self_t >>(data (), bl_str, bl_size, n_bl, value_type (scalar));
556+ } else {
557+ // MAM: implement recursive call to fill_with_scalar on (i,nda::ellipsis{})
558+ NDA_RUNTIME_ERROR << " fill_with_scalar: Not implemented yet for general layout. " ;
559+ }
537560 }
538- } else {
539- // no compile-time memory layout guarantees
540- for (auto &x : *this ) x = scalar;
541561 }
542562}
543563
@@ -556,7 +576,6 @@ void assign_from_scalar(Scalar const &scalar) noexcept {
556576 fill_with_scalar (0 );
557577 else
558578 fill_with_scalar (Scalar{0 * scalar}); // FIXME : improve this
559- const long imax = std::min (extent (0 ), extent (1 ));
560- for (long i = 0 ; i < imax; ++i) operator ()(i, i) = scalar;
579+ diagonal (*this ).fill_with_scalar (scalar);
561580 }
562581}
0 commit comments