4444
4545#include < concepts>
4646
47-
4847namespace TiledArray {
4948namespace detail {
5049
50+ // / pre-fetch to device
5151template <typename T>
5252void to_device (const UMTensor<T> &tensor) {
5353 auto stream = device::stream_for (tensor.range ());
5454 TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
55- const_cast <UMTensor<T> &>(tensor), stream);
55+ tensor, stream);
56+ }
57+
58+ // / pre-fetch to device (non-const)
59+ template <typename T>
60+ void to_device (UMTensor<T> &tensor) {
61+ auto stream = device::stream_for (tensor.range ());
62+ TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
63+ tensor, stream);
5664}
5765
5866// / get device data pointer
@@ -70,7 +78,8 @@ auto *device_data(UMTensor<T> &tensor) {
7078/// handles ComplexConjugate for the scaling functions;
7179/// follows the logic in device/btas.h
7280template <typename T, typename Scalar, typename Queue>
73- void apply_scale_factor (T* data, std::size_t size, const Scalar& factor, Queue& queue) {
81+ void apply_scale_factor (T *data, std::size_t size, const Scalar &factor,
82+ Queue &queue) {
7483 if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
7584 std::is_arithmetic_v<Scalar>) {
7685 blas::scal (size, factor, data, 1 , queue);
@@ -98,8 +107,9 @@ void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue&
98107
99108template <typename T, typename Scalar>
100109 requires TiledArray::detail::is_numeric_v<Scalar>
101- UMTensor<T> gemm (const UMTensor<T> &left, const UMTensor<T> &right, Scalar factor,
102- const TiledArray::math::GemmHelper &gemm_helper) {
110+ UMTensor<T> gemm (const UMTensor<T> &left, const UMTensor<T> &right,
111+ Scalar factor,
112+ const TiledArray::math::GemmHelper &gemm_helper) {
103113 // Check that the arguments are not empty and have the correct ranks
104114 TA_ASSERT (!left.empty ());
105115 TA_ASSERT (left.range ().rank () == gemm_helper.left_rank ());
@@ -153,8 +163,9 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right, Scalar facto
153163
154164template <typename T, typename Scalar>
155165 requires TiledArray::detail::is_numeric_v<Scalar>
156- void gemm (UMTensor<T> &result, const UMTensor<T> &left, const UMTensor<T> &right,
157- Scalar factor, const TiledArray::math::GemmHelper &gemm_helper) {
166+ void gemm (UMTensor<T> &result, const UMTensor<T> &left,
167+ const UMTensor<T> &right, Scalar factor,
168+ const TiledArray::math::GemmHelper &gemm_helper) {
158169 // Check that the result is not empty and has the correct rank
159170 TA_ASSERT (!result.empty ());
160171 TA_ASSERT (result.range ().rank () == gemm_helper.result_rank ());
@@ -262,7 +273,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
262273}
263274
264275template <typename T, typename Index>
265- UMTensor<T> &shift_to (UMTensor<T> &arg, const Index &bound_shift) {
276+ UMTensor<T> &shift_to (UMTensor<T> &arg, const Index &bound_shift) {
266277 // although shift_to is currently fine on shared objects since ranges are
267278 // not shared, this will change in the future
268279#ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED
@@ -277,7 +288,8 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
277288// /
278289
279290template <typename T>
280- UMTensor<T> permute (const UMTensor<T> &arg, const TiledArray::Permutation &perm) {
291+ UMTensor<T> permute (const UMTensor<T> &arg,
292+ const TiledArray::Permutation &perm) {
281293 TA_ASSERT (!arg.empty ());
282294 TA_ASSERT (perm.size () == arg.range ().rank ());
283295
@@ -300,7 +312,7 @@ UMTensor<T> permute(const UMTensor<T> &arg, const TiledArray::Permutation &perm
300312
301313template <typename T>
302314UMTensor<T> permute (const UMTensor<T> &arg,
303- const TiledArray::BipartitePermutation &perm) {
315+ const TiledArray::BipartitePermutation &perm) {
304316 TA_ASSERT (!arg.empty ());
305317 TA_ASSERT (inner_size (perm) == 0 ); // this must be a plain permutation
306318 return permute (arg, outer (perm));
@@ -313,13 +325,13 @@ UMTensor<T> permute(const UMTensor<T> &arg,
313325template <typename T, typename Scalar>
314326 requires TiledArray::detail::is_numeric_v<Scalar>
315327UMTensor<T> scale (const UMTensor<T> &arg, const Scalar factor) {
316-
317328 auto &queue = blasqueue_for (arg.range ());
318329 const auto stream = device::Stream (queue.device (), queue.stream ());
319330
320331 auto result = clone (arg);
321332
322- detail::apply_scale_factor (detail::device_data (result), result.size (), factor, queue);
333+ detail::apply_scale_factor (detail::device_data (result), result.size (), factor,
334+ queue);
323335
324336 device::sync_madness_task_with (stream);
325337 return result;
@@ -335,7 +347,8 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
335347
336348 // in-place scale
337349 // ComplexConjugate is handled as in device/btas.h
338- detail::apply_scale_factor (detail::device_data (arg), arg.size (), factor, queue);
350+ detail::apply_scale_factor (detail::device_data (arg), arg.size (), factor,
351+ queue);
339352
340353 device::sync_madness_task_with (stream);
341354 return arg;
@@ -344,7 +357,8 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
344357template <typename T, typename Scalar, typename Perm>
345358 requires TiledArray::detail::is_numeric_v<Scalar> &&
346359 TiledArray::detail::is_permutation_v<Perm>
347- UMTensor<T> scale (const UMTensor<T> &arg, const Scalar factor, const Perm &perm) {
360+ UMTensor<T> scale (const UMTensor<T> &arg, const Scalar factor,
361+ const Perm &perm) {
348362 auto result = scale (arg, factor);
349363 return permute (result, perm);
350364}
@@ -399,23 +413,25 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
399413
400414template <typename T, typename Scalar>
401415 requires TiledArray::detail::is_numeric_v<Scalar>
402- UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
416+ UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
417+ const Scalar factor) {
403418 auto result = add (arg1, arg2);
404419 return scale_to (result, factor);
405420}
406421
407422template <typename T, typename Perm>
408423 requires TiledArray::detail::is_permutation_v<Perm>
409- UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
424+ UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
425+ const Perm &perm) {
410426 auto result = add (arg1, arg2);
411427 return permute (result, perm);
412428}
413429
414430template <typename T, typename Scalar, typename Perm>
415431 requires TiledArray::detail::is_numeric_v<Scalar> &&
416432 TiledArray::detail::is_permutation_v<Perm>
417- UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
418- const Perm &perm) {
433+ UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
434+ const Scalar factor, const Perm &perm) {
419435 auto result = add (arg1, arg2, factor);
420436 return permute (result, perm);
421437}
@@ -442,7 +458,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
442458
443459template <typename T, typename Scalar>
444460 requires TiledArray::detail::is_numeric_v<Scalar>
445- UMTensor<T> &add_to (UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
461+ UMTensor<T> &add_to (UMTensor<T> &result, const UMTensor<T> &arg,
462+ const Scalar factor) {
446463 add_to (result, arg);
447464 return scale_to (result, factor);
448465}
@@ -474,23 +491,25 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
474491
475492template <typename T, typename Scalar>
476493 requires TiledArray::detail::is_numeric_v<Scalar>
477- UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
494+ UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
495+ const Scalar factor) {
478496 auto result = subt (arg1, arg2);
479497 return scale_to (result, factor);
480498}
481499
482500template <typename T, typename Perm>
483501 requires TiledArray::detail::is_permutation_v<Perm>
484- UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
502+ UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
503+ const Perm &perm) {
485504 auto result = subt (arg1, arg2);
486505 return permute (result, perm);
487506}
488507
489508template <typename T, typename Scalar, typename Perm>
490509 requires TiledArray::detail::is_numeric_v<Scalar> &&
491510 TiledArray::detail::is_permutation_v<Perm>
492- UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
493- const Perm &perm) {
511+ UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
512+ const Scalar factor, const Perm &perm) {
494513 auto result = subt (arg1, arg2, factor);
495514 return permute (result, perm);
496515}
@@ -517,7 +536,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
517536
518537template <typename T, typename Scalar>
519538 requires TiledArray::detail::is_numeric_v<Scalar>
520- UMTensor<T> &subt_to (UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
539+ UMTensor<T> &subt_to (UMTensor<T> &result, const UMTensor<T> &arg,
540+ const Scalar factor) {
521541 subt_to (result, arg);
522542 return scale_to (result, factor);
523543}
@@ -549,23 +569,25 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
549569
550570template <typename T, typename Scalar>
551571 requires TiledArray::detail::is_numeric_v<Scalar>
552- UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
572+ UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
573+ const Scalar factor) {
553574 auto result = mult (arg1, arg2);
554575 return scale_to (result, factor);
555576}
556577
557578template <typename T, typename Perm>
558579 requires TiledArray::detail::is_permutation_v<Perm>
559- UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
580+ UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
581+ const Perm &perm) {
560582 auto result = mult (arg1, arg2);
561583 return permute (result, perm);
562584}
563585
564586template <typename T, typename Scalar, typename Perm>
565587 requires TiledArray::detail::is_numeric_v<Scalar> &&
566588 TiledArray::detail::is_permutation_v<Perm>
567- UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
568- const Perm &perm) {
589+ UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
590+ const Scalar factor, const Perm &perm) {
569591 auto result = mult (arg1, arg2, factor);
570592 return permute (result, perm);
571593}
@@ -594,7 +616,8 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
594616
595617template <typename T, typename Scalar>
596618 requires TiledArray::detail::is_numeric_v<Scalar>
597- UMTensor<T> &mult_to (UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
619+ UMTensor<T> &mult_to (UMTensor<T> &result, const UMTensor<T> &arg,
620+ const Scalar factor) {
598621 mult_to (result, arg);
599622 return scale_to (result, factor);
600623}
@@ -604,7 +627,8 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg, const Scalar f
604627// /
605628
606629template <typename T>
607- typename UMTensor<T>::value_type dot (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
630+ typename UMTensor<T>::value_type dot (const UMTensor<T> &arg1,
631+ const UMTensor<T> &arg2) {
608632 auto &queue = blasqueue_for (arg1.range ());
609633 const auto stream = device::Stream (queue.device (), queue.stream ());
610634
0 commit comments