Skip to content

Commit 6bcc32f

Browse files
committed
UMTensor: no need to use value_type, use T directly + format
1 parent 720f678 commit 6bcc32f

File tree

1 file changed

+49
-69
lines changed

1 file changed

+49
-69
lines changed

src/TiledArray/device/um_tensor.h

Lines changed: 49 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -51,16 +51,16 @@ namespace detail {
5151
template <typename T>
5252
void to_device(const UMTensor<T> &tensor) {
5353
auto stream = device::stream_for(tensor.range());
54-
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
55-
tensor, stream);
54+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
55+
stream);
5656
}
5757

5858
/// pre-fetch to device (non-const)
5959
template <typename T>
6060
void to_device(UMTensor<T> &tensor) {
6161
auto stream = device::stream_for(tensor.range());
62-
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
63-
tensor, stream);
62+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
63+
stream);
6464
}
6565

6666
/// get device data pointer
@@ -82,7 +82,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
8282
Queue &queue) {
8383
if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
8484
std::is_arithmetic_v<Scalar>) {
85-
blas::scal(size, factor, data, 1, queue);
85+
blas::scal(size, factor, data, T(1), queue);
8686
} else {
8787
if constexpr (TiledArray::detail::is_complex_v<T>) {
8888
abort(); // fused conjugation requires custom kernels, not yet supported
@@ -93,7 +93,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
9393
Scalar,
9494
TiledArray::detail::ComplexConjugate<
9595
TiledArray::detail::ComplexNegTag>>) {
96-
blas::scal(size, static_cast<T>(-1), data, 1, queue);
96+
blas::scal(size, T(-1), data, T(1), queue);
9797
}
9898
}
9999
}
@@ -148,9 +148,8 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
148148
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
149149
const integer ldc = std::max(integer{1}, n);
150150

151-
using value_type = UMTensor<T>::value_type;
152-
value_type factor_t = value_type(factor);
153-
value_type zero(0);
151+
auto factor_t = T(factor);
152+
T zero(0);
154153

155154
blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
156155
gemm_helper.left_op(), n, m, k, factor_t,
@@ -210,9 +209,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
210209
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
211210
const integer ldc = std::max(integer{1}, n);
212211

213-
using value_type = UMTensor<T>::value_type;
214-
value_type factor_t = value_type(factor);
215-
value_type one(1);
212+
auto factor_t = T(factor);
213+
T one(1);
216214

217215
blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
218216
gemm_helper.left_op(), n, m, k, factor_t,
@@ -238,8 +236,8 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
238236

239237
// copy data
240238
auto &queue = blasqueue_for(result.range());
241-
blas::copy(result.size(), detail::device_data(arg), 1,
242-
detail::device_data(result), 1, queue);
239+
blas::copy(result.size(), detail::device_data(arg), T(1),
240+
detail::device_data(result), T(1), queue);
243241
device::sync_madness_task_with(stream);
244242
return result;
245243
}
@@ -266,19 +264,14 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
266264
detail::to_device(result);
267265

268266
// copy data
269-
blas::copy(result.size(), detail::device_data(arg), 1,
270-
detail::device_data(result), 1, queue);
267+
blas::copy(result.size(), detail::device_data(arg), T(1),
268+
detail::device_data(result), T(1), queue);
271269
device::sync_madness_task_with(stream);
272270
return result;
273271
}
274272

275273
template <typename T, typename Index>
276274
UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
277-
// although shift_to is currently fine on shared objects since ranges are
278-
// not shared, this will change in the future
279-
#ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED
280-
TA_ASSERT(data_.use_count() <= 1);
281-
#endif
282275
const_cast<TiledArray::Range &>(arg.range()).inplace_shift(bound_shift);
283276
return arg;
284277
}
@@ -303,8 +296,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
303296
detail::to_device(result);
304297

305298
// invoke permute function from librett
306-
using value_type = UMTensor<T>::value_type;
307-
librett_permute(const_cast<value_type *>(detail::device_data(arg)),
299+
librett_permute(const_cast<T *>(detail::device_data(arg)),
308300
detail::device_data(result), arg.range(), perm, stream);
309301
device::sync_madness_task_with(stream);
310302
return result;
@@ -369,8 +361,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
369361

370362
template <typename T>
371363
UMTensor<T> neg(const UMTensor<T> &arg) {
372-
using value_type = UMTensor<T>::value_type;
373-
return scale(arg, value_type(-1.0));
364+
return scale(arg, T(-1.0));
374365
}
375366

376367
template <typename T, typename Perm>
@@ -382,8 +373,7 @@ UMTensor<T> neg(const UMTensor<T> &arg, const Perm &perm) {
382373

383374
template <typename T>
384375
UMTensor<T> &neg_to(UMTensor<T> &arg) {
385-
using value_type = UMTensor<T>::value_type;
386-
return scale_to(arg, value_type(-1.0));
376+
return scale_to(arg, T(-1.0));
387377
}
388378

389379
///
@@ -402,11 +392,10 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
402392
detail::to_device(result);
403393

404394
// result = arg1 + arg2
405-
using value_type = typename UMTensor<T>::value_type;
406-
blas::copy(result.size(), detail::device_data(arg1), 1,
407-
detail::device_data(result), 1, queue);
408-
blas::axpy(result.size(), value_type(1), detail::device_data(arg2), 1,
409-
detail::device_data(result), 1, queue);
395+
blas::copy(result.size(), detail::device_data(arg1), T(1),
396+
detail::device_data(result), T(1), queue);
397+
blas::axpy(result.size(), T(1), detail::device_data(arg2), T(1),
398+
detail::device_data(result), T(1), queue);
410399
device::sync_madness_task_with(stream);
411400
return result;
412401
}
@@ -449,9 +438,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
449438
detail::to_device(arg);
450439

451440
// result += arg
452-
using value_type = typename UMTensor<T>::value_type;
453-
blas::axpy(result.size(), value_type(1), detail::device_data(arg), 1,
454-
detail::device_data(result), 1, queue);
441+
blas::axpy(result.size(), T(1), detail::device_data(arg), T(1),
442+
detail::device_data(result), T(1), queue);
455443
device::sync_madness_task_with(stream);
456444
return result;
457445
}
@@ -480,11 +468,10 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
480468
detail::to_device(result);
481469

482470
// result = arg1 - arg2
483-
using value_type = typename UMTensor<T>::value_type;
484-
blas::copy(result.size(), detail::device_data(arg1), 1,
485-
detail::device_data(result), 1, queue);
486-
blas::axpy(result.size(), value_type(-1), detail::device_data(arg2), 1,
487-
detail::device_data(result), 1, queue);
471+
blas::copy(result.size(), detail::device_data(arg1), T(1),
472+
detail::device_data(result), T(1), queue);
473+
blas::axpy(result.size(), T(-1), detail::device_data(arg2), T(1),
474+
detail::device_data(result), T(1), queue);
488475
device::sync_madness_task_with(stream);
489476
return result;
490477
}
@@ -527,9 +514,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
527514
detail::to_device(arg);
528515

529516
// result -= arg
530-
using value_type = typename UMTensor<T>::value_type;
531-
blas::axpy(result.size(), value_type(-1), detail::device_data(arg), 1,
532-
detail::device_data(result), 1, queue);
517+
blas::axpy(result.size(), T(-1), detail::device_data(arg), T(1),
518+
detail::device_data(result), T(1), queue);
533519
device::sync_madness_task_with(stream);
534520
return result;
535521
}
@@ -548,12 +534,10 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
548534

549535
template <typename T>
550536
UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
551-
std::size_t n = arg1.size();
552-
TA_ASSERT(arg2.size() == n);
537+
TA_ASSERT(arg1.size() == arg2.size());
553538

554539
auto stream = device::stream_for(arg1.range());
555540

556-
using value_type = typename UMTensor<T>::value_type;
557541
UMTensor<T> result(arg1.range());
558542

559543
detail::to_device(arg1);
@@ -562,7 +546,7 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
562546

563547
// element-wise multiplication
564548
device::mult_kernel(detail::device_data(result), detail::device_data(arg1),
565-
detail::device_data(arg2), n, stream);
549+
detail::device_data(arg2), arg1.size(), stream);
566550
device::sync_madness_task_with(stream);
567551
return result;
568552
}
@@ -599,16 +583,14 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
599583
template <typename T>
600584
UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
601585
auto stream = device::stream_for(result.range());
602-
603-
std::size_t n = result.size();
604-
TA_ASSERT(n == arg.size());
586+
TA_ASSERT(result.size() == arg.size());
605587

606588
detail::to_device(result);
607589
detail::to_device(arg);
608590

609591
// in-place element-wise multiplication
610592
device::mult_to_kernel(detail::device_data(result), detail::device_data(arg),
611-
n, stream);
593+
result.size(), stream);
612594

613595
device::sync_madness_task_with(stream);
614596
return result;
@@ -627,19 +609,17 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
627609
///
628610

629611
template <typename T>
630-
typename UMTensor<T>::value_type dot(const UMTensor<T> &arg1,
631-
const UMTensor<T> &arg2) {
612+
T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
632613
auto &queue = blasqueue_for(arg1.range());
633614
const auto stream = device::Stream(queue.device(), queue.stream());
634615

635616
detail::to_device(arg1);
636617
detail::to_device(arg2);
637618

638619
// compute dot product using device BLAS
639-
using value_type = typename UMTensor<T>::value_type;
640-
value_type result = value_type(0);
641-
blas::dot(arg1.size(), detail::device_data(arg1), 1,
642-
detail::device_data(arg2), 1, &result, queue);
620+
auto result = T(0);
621+
blas::dot(arg1.size(), detail::device_data(arg1), T(1),
622+
detail::device_data(arg2), T(1), &result, queue);
643623
device::sync_madness_task_with(stream);
644624
return result;
645625
}
@@ -649,28 +629,27 @@ typename UMTensor<T>::value_type dot(const UMTensor<T> &arg1,
649629
///
650630

651631
template <typename T>
652-
typename UMTensor<T>::value_type squared_norm(const UMTensor<T> &arg) {
632+
T squared_norm(const UMTensor<T> &arg) {
653633
auto &queue = blasqueue_for(arg.range());
654634
const auto stream = device::Stream(queue.device(), queue.stream());
655635

656636
detail::to_device(arg);
657637

658638
// compute squared norm using dot
659-
using value_type = typename UMTensor<T>::value_type;
660-
value_type result = value_type(0);
661-
blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
662-
1, &result, queue);
639+
auto result = T(0);
640+
blas::dot(arg.size(), detail::device_data(arg), T(1),
641+
detail::device_data(arg), T(1), &result, queue);
663642
device::sync_madness_task_with(stream);
664643
return result;
665644
}
666645

667646
template <typename T>
668-
typename UMTensor<T>::value_type norm(const UMTensor<T> &arg) {
647+
T norm(const UMTensor<T> &arg) {
669648
return std::sqrt(squared_norm(arg));
670649
}
671650

672651
template <typename T>
673-
typename UMTensor<T>::value_type sum(const UMTensor<T> &arg) {
652+
T sum(const UMTensor<T> &arg) {
674653
detail::to_device(arg);
675654
auto stream = device::stream_for(arg.range());
676655
auto result =
@@ -680,7 +659,7 @@ typename UMTensor<T>::value_type sum(const UMTensor<T> &arg) {
680659
}
681660

682661
template <typename T>
683-
typename UMTensor<T>::value_type product(const UMTensor<T> &arg) {
662+
T product(const UMTensor<T> &arg) {
684663
detail::to_device(arg);
685664
auto stream = device::stream_for(arg.range());
686665
auto result =
@@ -690,7 +669,7 @@ typename UMTensor<T>::value_type product(const UMTensor<T> &arg) {
690669
}
691670

692671
template <typename T>
693-
typename UMTensor<T>::value_type max(const UMTensor<T> &arg) {
672+
T max(const UMTensor<T> &arg) {
694673
detail::to_device(arg);
695674
auto stream = device::stream_for(arg.range());
696675
auto result =
@@ -700,7 +679,7 @@ typename UMTensor<T>::value_type max(const UMTensor<T> &arg) {
700679
}
701680

702681
template <typename T>
703-
typename UMTensor<T>::value_type min(const UMTensor<T> &arg) {
682+
T min(const UMTensor<T> &arg) {
704683
detail::to_device(arg);
705684
auto stream = device::stream_for(arg.range());
706685
auto result =
@@ -710,7 +689,7 @@ typename UMTensor<T>::value_type min(const UMTensor<T> &arg) {
710689
}
711690

712691
template <typename T>
713-
typename UMTensor<T>::value_type abs_max(const UMTensor<T> &arg) {
692+
T abs_max(const UMTensor<T> &arg) {
714693
detail::to_device(arg);
715694
auto stream = device::stream_for(arg.range());
716695
auto result =
@@ -720,12 +699,13 @@ typename UMTensor<T>::value_type abs_max(const UMTensor<T> &arg) {
720699
}
721700

722701
template <typename T>
723-
typename UMTensor<T>::value_type abs_min(const UMTensor<T> &arg) {
702+
T abs_min(const UMTensor<T> &arg) {
724703
detail::to_device(arg);
725704
auto stream = device::stream_for(arg.range());
726705
auto result =
727706
device::absmin_kernel(detail::device_data(arg), arg.size(), stream);
728707
device::sync_madness_task_with(stream);
708+
729709
return result;
730710
}
731711

0 commit comments

Comments (0)