Skip to content

Commit 60d4abd

Browse files
committed
UMTensor: introduce const and non-const versions of to_device
1 parent e9f7d95 commit 60d4abd

File tree

1 file changed

+54
-30
lines changed

1 file changed

+54
-30
lines changed

src/TiledArray/device/um_tensor.h

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,23 @@
4444

4545
#include <concepts>
4646

47-
4847
namespace TiledArray {
4948
namespace detail {
5049

50+
/// pre-fetch to device
5151
template <typename T>
5252
void to_device(const UMTensor<T> &tensor) {
5353
auto stream = device::stream_for(tensor.range());
5454
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
55-
const_cast<UMTensor<T> &>(tensor), stream);
55+
tensor, stream);
56+
}
57+
58+
/// pre-fetch to device (non-const)
59+
template <typename T>
60+
void to_device(UMTensor<T> &tensor) {
61+
auto stream = device::stream_for(tensor.range());
62+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
63+
tensor, stream);
5664
}
5765

5866
/// get device data pointer
@@ -70,7 +78,8 @@ auto *device_data(UMTensor<T> &tensor) {
7078
/// ComplexConjugate handling for scaling functions
7179
/// follows the logic in device/btas.h
7280
template <typename T, typename Scalar, typename Queue>
73-
void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue& queue) {
81+
void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
82+
Queue &queue) {
7483
if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
7584
std::is_arithmetic_v<Scalar>) {
7685
blas::scal(size, factor, data, 1, queue);
@@ -98,8 +107,9 @@ void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue&
98107

99108
template <typename T, typename Scalar>
100109
requires TiledArray::detail::is_numeric_v<Scalar>
101-
UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right, Scalar factor,
102-
const TiledArray::math::GemmHelper &gemm_helper) {
110+
UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
111+
Scalar factor,
112+
const TiledArray::math::GemmHelper &gemm_helper) {
103113
// Check that the arguments are not empty and have the correct ranks
104114
TA_ASSERT(!left.empty());
105115
TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
@@ -153,8 +163,9 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right, Scalar facto
153163

154164
template <typename T, typename Scalar>
155165
requires TiledArray::detail::is_numeric_v<Scalar>
156-
void gemm(UMTensor<T> &result, const UMTensor<T> &left, const UMTensor<T> &right,
157-
Scalar factor, const TiledArray::math::GemmHelper &gemm_helper) {
166+
void gemm(UMTensor<T> &result, const UMTensor<T> &left,
167+
const UMTensor<T> &right, Scalar factor,
168+
const TiledArray::math::GemmHelper &gemm_helper) {
158169
// Check that the result is not empty and has the correct rank
159170
TA_ASSERT(!result.empty());
160171
TA_ASSERT(result.range().rank() == gemm_helper.result_rank());
@@ -262,7 +273,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
262273
}
263274

264275
template <typename T, typename Index>
265-
UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
276+
UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
266277
// although shift_to is currently fine on shared objects since ranges are
267278
// not shared, this will change in the future
268279
#ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED
@@ -277,7 +288,8 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
277288
///
278289

279290
template <typename T>
280-
UMTensor<T> permute(const UMTensor<T> &arg, const TiledArray::Permutation &perm) {
291+
UMTensor<T> permute(const UMTensor<T> &arg,
292+
const TiledArray::Permutation &perm) {
281293
TA_ASSERT(!arg.empty());
282294
TA_ASSERT(perm.size() == arg.range().rank());
283295

@@ -300,7 +312,7 @@ UMTensor<T> permute(const UMTensor<T> &arg, const TiledArray::Permutation &perm
300312

301313
template <typename T>
302314
UMTensor<T> permute(const UMTensor<T> &arg,
303-
const TiledArray::BipartitePermutation &perm) {
315+
const TiledArray::BipartitePermutation &perm) {
304316
TA_ASSERT(!arg.empty());
305317
TA_ASSERT(inner_size(perm) == 0); // this must be a plain permutation
306318
return permute(arg, outer(perm));
@@ -313,13 +325,13 @@ UMTensor<T> permute(const UMTensor<T> &arg,
313325
template <typename T, typename Scalar>
314326
requires TiledArray::detail::is_numeric_v<Scalar>
315327
UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
316-
317328
auto &queue = blasqueue_for(arg.range());
318329
const auto stream = device::Stream(queue.device(), queue.stream());
319330

320331
auto result = clone(arg);
321332

322-
detail::apply_scale_factor(detail::device_data(result), result.size(), factor, queue);
333+
detail::apply_scale_factor(detail::device_data(result), result.size(), factor,
334+
queue);
323335

324336
device::sync_madness_task_with(stream);
325337
return result;
@@ -335,7 +347,8 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
335347

336348
// in-place scale
337349
// ComplexConjugate is handled as in device/btas.h
338-
detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor, queue);
350+
detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor,
351+
queue);
339352

340353
device::sync_madness_task_with(stream);
341354
return arg;
@@ -344,7 +357,8 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
344357
template <typename T, typename Scalar, typename Perm>
345358
requires TiledArray::detail::is_numeric_v<Scalar> &&
346359
TiledArray::detail::is_permutation_v<Perm>
347-
UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor, const Perm &perm) {
360+
UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
361+
const Perm &perm) {
348362
auto result = scale(arg, factor);
349363
return permute(result, perm);
350364
}
@@ -399,23 +413,25 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
399413

400414
template <typename T, typename Scalar>
401415
requires TiledArray::detail::is_numeric_v<Scalar>
402-
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
416+
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
417+
const Scalar factor) {
403418
auto result = add(arg1, arg2);
404419
return scale_to(result, factor);
405420
}
406421

407422
template <typename T, typename Perm>
408423
requires TiledArray::detail::is_permutation_v<Perm>
409-
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
424+
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
425+
const Perm &perm) {
410426
auto result = add(arg1, arg2);
411427
return permute(result, perm);
412428
}
413429

414430
template <typename T, typename Scalar, typename Perm>
415431
requires TiledArray::detail::is_numeric_v<Scalar> &&
416432
TiledArray::detail::is_permutation_v<Perm>
417-
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
418-
const Perm &perm) {
433+
UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
434+
const Scalar factor, const Perm &perm) {
419435
auto result = add(arg1, arg2, factor);
420436
return permute(result, perm);
421437
}
@@ -442,7 +458,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
442458

443459
template <typename T, typename Scalar>
444460
requires TiledArray::detail::is_numeric_v<Scalar>
445-
UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
461+
UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg,
462+
const Scalar factor) {
446463
add_to(result, arg);
447464
return scale_to(result, factor);
448465
}
@@ -474,23 +491,25 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
474491

475492
template <typename T, typename Scalar>
476493
requires TiledArray::detail::is_numeric_v<Scalar>
477-
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
494+
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
495+
const Scalar factor) {
478496
auto result = subt(arg1, arg2);
479497
return scale_to(result, factor);
480498
}
481499

482500
template <typename T, typename Perm>
483501
requires TiledArray::detail::is_permutation_v<Perm>
484-
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
502+
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
503+
const Perm &perm) {
485504
auto result = subt(arg1, arg2);
486505
return permute(result, perm);
487506
}
488507

489508
template <typename T, typename Scalar, typename Perm>
490509
requires TiledArray::detail::is_numeric_v<Scalar> &&
491510
TiledArray::detail::is_permutation_v<Perm>
492-
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
493-
const Perm &perm) {
511+
UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
512+
const Scalar factor, const Perm &perm) {
494513
auto result = subt(arg1, arg2, factor);
495514
return permute(result, perm);
496515
}
@@ -517,7 +536,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
517536

518537
template <typename T, typename Scalar>
519538
requires TiledArray::detail::is_numeric_v<Scalar>
520-
UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
539+
UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
540+
const Scalar factor) {
521541
subt_to(result, arg);
522542
return scale_to(result, factor);
523543
}
@@ -549,23 +569,25 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
549569

550570
template <typename T, typename Scalar>
551571
requires TiledArray::detail::is_numeric_v<Scalar>
552-
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor) {
572+
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
573+
const Scalar factor) {
553574
auto result = mult(arg1, arg2);
554575
return scale_to(result, factor);
555576
}
556577

557578
template <typename T, typename Perm>
558579
requires TiledArray::detail::is_permutation_v<Perm>
559-
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Perm &perm) {
580+
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
581+
const Perm &perm) {
560582
auto result = mult(arg1, arg2);
561583
return permute(result, perm);
562584
}
563585

564586
template <typename T, typename Scalar, typename Perm>
565587
requires TiledArray::detail::is_numeric_v<Scalar> &&
566588
TiledArray::detail::is_permutation_v<Perm>
567-
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2, const Scalar factor,
568-
const Perm &perm) {
589+
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
590+
const Scalar factor, const Perm &perm) {
569591
auto result = mult(arg1, arg2, factor);
570592
return permute(result, perm);
571593
}
@@ -594,7 +616,8 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
594616

595617
template <typename T, typename Scalar>
596618
requires TiledArray::detail::is_numeric_v<Scalar>
597-
UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg, const Scalar factor) {
619+
UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
620+
const Scalar factor) {
598621
mult_to(result, arg);
599622
return scale_to(result, factor);
600623
}
@@ -604,7 +627,8 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg, const Scalar f
604627
///
605628

606629
template <typename T>
607-
typename UMTensor<T>::value_type dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
630+
typename UMTensor<T>::value_type dot(const UMTensor<T> &arg1,
631+
const UMTensor<T> &arg2) {
608632
auto &queue = blasqueue_for(arg1.range());
609633
const auto stream = device::Stream(queue.device(), queue.stream());
610634

0 commit comments

Comments
 (0)