@@ -51,16 +51,16 @@ namespace detail {
5151template <typename T>
5252void to_device (const UMTensor<T> &tensor) {
5353 auto stream = device::stream_for (tensor.range ());
54- TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
55- tensor, stream);
54+ TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
55+ stream);
5656}
5757
5858/// pre-fetch to device (non-const)
5959template <typename T>
6060void to_device (UMTensor<T> &tensor) {
6161 auto stream = device::stream_for (tensor.range ());
62- TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
63- tensor, stream);
62+ TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
63+ stream);
6464}
6565
6666/// get device data pointer
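Every operation in this file follows the unified-memory pattern visible in the hunks above: prefetch operands to the device on the stream or BLAS queue derived from the tile's range, run the device computation, then fence the enclosing MADNESS task on that stream. A minimal sketch of the pattern, assuming only names that appear in this diff (the helper itself is illustrative, not part of the commit):

// Sketch: the prefetch / compute / sync pattern used throughout this file.
template <typename T>
void scaled_add_example(UMTensor<T> &y, const UMTensor<T> &x, T alpha) {
  auto &queue = blasqueue_for(y.range());  // device BLAS queue for this tile
  detail::to_device(x);                    // prefetch operands to the device
  detail::to_device(y);
  // y += alpha * x; the 1s are integer strides (incx/incy), not values of T
  blas::axpy(y.size(), alpha, detail::device_data(x), 1,
             detail::device_data(y), 1, queue);
  device::sync_madness_task_with(device::Stream(queue.device(), queue.stream()));
}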
@@ -82,7 +82,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
8282 Queue &queue) {
8383 if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
8484 std::is_arithmetic_v<Scalar>) {
85- blas::scal (size, factor, data, 1 , queue);
85+ blas::scal (size, T (factor), data, 1 , queue);
8686 } else {
8787 if constexpr (TiledArray::detail::is_complex_v<T>) {
8888 abort (); // fused conjugation requires custom kernels, not yet supported
@@ -93,7 +93,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
9393 Scalar,
9494 TiledArray::detail::ComplexConjugate<
9595 TiledArray::detail::ComplexNegTag>>) {
96- blas::scal (size, static_cast <T> (-1 ), data, 1 , queue);
96+ blas::scal (size, T (-1 ), data, 1 , queue);
9797 }
9898 }
9999 }
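A note that applies to every BLAS-1 call in this commit (scal, copy, axpy, dot): in the blaspp-style device API the trailing per-vector arguments are integer strides, so contiguous tiles always pass 1; only the alpha argument carries type T. Approximate shapes of the overloads used here, for reference (blaspp is authoritative):

// Approximate device BLAS-1 shapes assumed by this file:
//   blas::scal(int64_t n, T alpha, T *x, int64_t incx, blas::Queue &q);
//   blas::copy(int64_t n, const T *x, int64_t incx, T *y, int64_t incy, blas::Queue &q);
//   blas::axpy(int64_t n, T alpha, const T *x, int64_t incx, T *y, int64_t incy, blas::Queue &q);
//   blas::dot (int64_t n, const T *x, int64_t incx, const T *y, int64_t incy, T *result, blas::Queue &q);
blas::scal(size, T(factor), data, /*incx=*/1, queue);  // scale a contiguous buffer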
@@ -148,9 +148,8 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
148148 (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k));
149149 const integer ldc = std::max (integer{1 }, n);
150150
151- using value_type = UMTensor<T>::value_type;
152- value_type factor_t = value_type (factor);
153- value_type zero (0 );
151+ auto factor_t = T (factor);
152+ T zero (0 );
154153
155154 blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
156155 gemm_helper.left_op (), n, m, k, factor_t ,
@@ -210,9 +209,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
210209 (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k));
211210 const integer ldc = std::max (integer{1 }, n);
212211
213- using value_type = UMTensor<T>::value_type;
214- value_type factor_t = value_type (factor);
215- value_type one (1 );
212+ auto factor_t = T (factor);
213+ T one (1 );
216214
217215 blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
218216 gemm_helper.left_op (), n, m, k, factor_t ,
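One detail of both gemm hunks deserves a remark: the tiles are row-major, while device BLAS is column-major. A row-major matrix reinterpreted as column-major is its transpose, so the code evaluates C^T = op(B)^T op(A)^T, which is why right_op precedes left_op and n precedes m in the calls above. Schematically (operand names here are illustrative):

// Row-major C(m x n) = op(A)(m x k) * op(B)(k x n) with a column-major BLAS:
// view each row-major buffer as its column-major transpose and compute
//   C^T (n x m) = op(B)^T (n x k) * op(A)^T (k x m),
// hence the swapped operand order and the swapped m/n in this file's calls.
blas::gemm(blas::Layout::ColMajor, op_b, op_a, n, m, k, alpha,
           b_data, ldb, a_data, lda, beta, c_data, ldc, queue);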
@@ -238,8 +236,8 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
238236
239237 // copy data
240238 auto &queue = blasqueue_for (result.range ());
241- blas::copy (result.size (), detail::device_data (arg), 1 ,
242- detail::device_data (result), 1 , queue);
239+ blas::copy (result.size (), detail::device_data (arg), 1 ,
240+ detail::device_data (result), 1 , queue);
243241 device::sync_madness_task_with (stream);
244242 return result;
245243}
@@ -266,19 +264,14 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
266264 detail::to_device (result);
267265
268266 // copy data
269- blas::copy (result.size (), detail::device_data (arg), 1 ,
270- detail::device_data (result), 1 , queue);
267+ blas::copy (result.size (), detail::device_data (arg), 1 ,
268+ detail::device_data (result), 1 , queue);
271269 device::sync_madness_task_with (stream);
272270 return result;
273271}
274272
275273template <typename T, typename Index>
276274UMTensor<T> &shift_to (UMTensor<T> &arg, const Index &bound_shift) {
277- // although shift_to is currently fine on shared objects since ranges are
278- // not shared, this will change in the future
279- #ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED
280- TA_ASSERT (data_.use_count () <= 1 );
281- #endif
282275 const_cast <TiledArray::Range &>(arg.range ()).inplace_shift (bound_shift);
283276 return arg;
284277}
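shift clones the tile into a translated range, whereas shift_to only translates the range in place, with no element motion, which is why it reduces to the const_cast + inplace_shift one-liner above. A usage sketch (the Range constructor from bound lists is assumed):

// Sketch: shift copies into a moved box; shift_to just moves the box.
UMTensor<double> t(TiledArray::Range({0, 0}, {4, 4}));
auto s = shift(t, std::array<long, 2>{2, 2});  // s.range() spans [2,6) x [2,6)
shift_to(t, std::array<long, 2>{1, 1});        // t.range() is now [1,5) x [1,5)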
@@ -303,8 +296,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
303296 detail::to_device (result);
304297
305298 // invoke permute function from librett
306- using value_type = UMTensor<T>::value_type;
307- librett_permute (const_cast <value_type *>(detail::device_data (arg)),
299+ librett_permute (const_cast <T *>(detail::device_data (arg)),
308300 detail::device_data (result), arg.range (), perm, stream);
309301 device::sync_madness_task_with (stream);
310302 return result;
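librett_permute delegates the device-side tensor transpose to the LibreTT library; as the call shows, it needs only the source and destination pointers, the source range, the permutation, and the stream. Through the public wrapper this looks like (a sketch; tile t as in the earlier examples):

// Sketch: permuting a rank-2 tile on the device via LibreTT.
TiledArray::Permutation perm{1, 0};  // swap the two modes
auto tp = permute(t, perm);          // tp.range() == perm applied to t.range()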
@@ -369,8 +361,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
369361
370362template <typename T>
371363UMTensor<T> neg (const UMTensor<T> &arg) {
372- using value_type = UMTensor<T>::value_type;
373- return scale (arg, value_type (-1.0 ));
364+ return scale (arg, T (-1.0 ));
374365}
375366
376367template <typename T, typename Perm>
@@ -382,8 +373,7 @@ UMTensor<T> neg(const UMTensor<T> &arg, const Perm &perm) {
382373
383374template <typename T>
384375UMTensor<T> &neg_to (UMTensor<T> &arg) {
385- using value_type = UMTensor<T>::value_type;
386- return scale_to (arg, value_type (-1.0 ));
376+ return scale_to (arg, T (-1.0 ));
387377}
388378
389379///
@@ -402,11 +392,10 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
402392 detail::to_device (result);
403393
404394 // result = arg1 + arg2
405- using value_type = typename UMTensor<T>::value_type;
406- blas::copy (result.size (), detail::device_data (arg1), 1 ,
407- detail::device_data (result), 1 , queue);
408- blas::axpy (result.size (), value_type (1 ), detail::device_data (arg2), 1 ,
409- detail::device_data (result), 1 , queue);
395+ blas::copy (result.size (), detail::device_data (arg1), 1 ,
396+ detail::device_data (result), 1 , queue);
397+ blas::axpy (result.size (), T (1 ), detail::device_data (arg2), 1 ,
398+ detail::device_data (result), 1 , queue);
410399 device::sync_madness_task_with (stream);
411400 return result;
412401}
@@ -449,9 +438,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
449438 detail::to_device (arg);
450439
451440 // result += arg
452- using value_type = typename UMTensor<T>::value_type;
453- blas::axpy (result.size (), value_type (1 ), detail::device_data (arg), 1 ,
454- detail::device_data (result), 1 , queue);
441+ blas::axpy (result.size (), T (1 ), detail::device_data (arg), 1 ,
442+ detail::device_data (result), 1 , queue);
455443 device::sync_madness_task_with (stream);
456444 return result;
457445}
@@ -480,11 +468,10 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
480468 detail::to_device (result);
481469
482470 // result = arg1 - arg2
483- using value_type = typename UMTensor<T>::value_type;
484- blas::copy (result.size (), detail::device_data (arg1), 1 ,
485- detail::device_data (result), 1 , queue);
486- blas::axpy (result.size (), value_type (-1 ), detail::device_data (arg2), 1 ,
487- detail::device_data (result), 1 , queue);
471+ blas::copy (result.size (), detail::device_data (arg1), 1 ,
472+ detail::device_data (result), 1 , queue);
473+ blas::axpy (result.size (), T (-1 ), detail::device_data (arg2), 1 ,
474+ detail::device_data (result), 1 , queue);
488475 device::sync_madness_task_with (stream);
489476 return result;
490477}
@@ -527,9 +514,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
527514 detail::to_device (arg);
528515
529516 // result -= arg
530- using value_type = typename UMTensor<T>::value_type;
531- blas::axpy (result.size (), value_type (-1 ), detail::device_data (arg), 1 ,
532- detail::device_data (result), 1 , queue);
517+ blas::axpy (result.size (), T (-1 ), detail::device_data (arg), 1 ,
518+ detail::device_data (result), 1 , queue);
533519 device::sync_madness_task_with (stream);
534520 return result;
535521}
@@ -548,12 +534,10 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
548534
549535template <typename T>
550536UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
551- std::size_t n = arg1.size ();
552- TA_ASSERT (arg2.size () == n);
537+ TA_ASSERT (arg1.size () == arg2.size ());
553538
554539 auto stream = device::stream_for (arg1.range ());
555540
556- using value_type = typename UMTensor<T>::value_type;
557541 UMTensor<T> result (arg1.range ());
558542
559543 detail::to_device (arg1);
@@ -562,7 +546,7 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
562546
563547 // element-wise multiplication
564548 device::mult_kernel (detail::device_data (result), detail::device_data (arg1),
565- detail::device_data (arg2), n , stream);
549+ detail::device_data (arg2), arg1.size (), stream);
566550 device::sync_madness_task_with (stream);
567551 return result;
568552}
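BLAS offers no element-wise (Hadamard) product, hence the custom device::mult_kernel. Its contract, written as an equivalent host-side loop (reference semantics only; the actual implementation is a device kernel launched on stream):

// Reference semantics of device::mult_kernel (illustrative, host-side):
template <typename T>
void mult_kernel_reference(T *result, const T *a, const T *b, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    result[i] = a[i] * b[i];  // element-wise product
}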
@@ -599,16 +583,14 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
599583template <typename T>
600584UMTensor<T> &mult_to (UMTensor<T> &result, const UMTensor<T> &arg) {
601585 auto stream = device::stream_for (result.range ());
602-
603- std::size_t n = result.size ();
604- TA_ASSERT (n == arg.size ());
586+ TA_ASSERT (result.size () == arg.size ());
605587
606588 detail::to_device (result);
607589 detail::to_device (arg);
608590
609591 // in-place element-wise multiplication
610592 device::mult_to_kernel (detail::device_data (result), detail::device_data (arg),
611- n , stream);
593+ result.size (), stream);
612594
613595 device::sync_madness_task_with (stream);
614596 return result;
@@ -627,19 +609,17 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
627609///
628610///
629611template <typename T>
630- typename UMTensor<T>::value_type dot (const UMTensor<T> &arg1,
631- const UMTensor<T> &arg2) {
612+ T dot (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
632613 auto &queue = blasqueue_for (arg1.range ());
633614 const auto stream = device::Stream (queue.device (), queue.stream ());
634615
635616 detail::to_device (arg1);
636617 detail::to_device (arg2);
637618
638619 // compute dot product using device BLAS
639- using value_type = typename UMTensor<T>::value_type;
640- value_type result = value_type (0 );
641- blas::dot (arg1.size (), detail::device_data (arg1), 1 ,
642- detail::device_data (arg2), 1 , &result, queue);
620+ auto result = T (0 );
621+ blas::dot (arg1.size (), detail::device_data (arg1), 1 ,
622+ detail::device_data (arg2), 1 , &result, queue);
643623 device::sync_madness_task_with (stream);
644624 return result;
645625}
@@ -649,28 +629,27 @@ typename UMTensor<T>::value_type dot(const UMTensor<T> &arg1,
649629///
650630///
651631template <typename T>
652- typename UMTensor<T>::value_type squared_norm (const UMTensor<T> &arg) {
632+ T squared_norm (const UMTensor<T> &arg) {
653633 auto &queue = blasqueue_for (arg.range ());
654634 const auto stream = device::Stream (queue.device (), queue.stream ());
655635
656636 detail::to_device (arg);
657637
658638 // compute squared norm using dot
659- using value_type = typename UMTensor<T>::value_type;
660- value_type result = value_type (0 );
661- blas::dot (arg.size (), detail::device_data (arg), 1 , detail::device_data (arg),
662- 1 , &result, queue);
639+ auto result = T (0 );
640+ blas::dot (arg.size (), detail::device_data (arg), 1 ,
641+ detail::device_data (arg), 1 , &result, queue);
663642 device::sync_madness_task_with (stream);
664643 return result;
665644}
666645
667646template <typename T>
668- typename UMTensor<T>::value_type norm (const UMTensor<T> &arg) {
647+ T norm (const UMTensor<T> &arg) {
669648 return std::sqrt (squared_norm (arg));
670649}
671650
672651template <typename T>
673- typename UMTensor<T>::value_type sum (const UMTensor<T> &arg) {
652+ T sum (const UMTensor<T> &arg) {
674653 detail::to_device (arg);
675654 auto stream = device::stream_for (arg.range ());
676655 auto result =
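From sum here through abs_min at the end of the diff, every scalar reduction has the same shape: prefetch, launch the matching reduction kernel on the range's stream, fence the MADNESS task, return the value. Condensed into one hedged template (the kernel parameter stands in for sum_kernel, product_kernel, max_kernel, and so on):

// Sketch: the shared shape of the scalar reductions below.
template <typename T, typename ReduceKernel>
T reduce_example(const UMTensor<T> &arg, ReduceKernel kernel) {
  detail::to_device(arg);                         // make data device-resident
  auto stream = device::stream_for(arg.range());
  auto result = kernel(detail::device_data(arg), arg.size(), stream);
  device::sync_madness_task_with(stream);         // fence before returning
  return result;
}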
@@ -680,7 +659,7 @@ typename UMTensor<T>::value_type sum(const UMTensor<T> &arg) {
680659}
681660
682661template <typename T>
683- typename UMTensor<T>::value_type product (const UMTensor<T> &arg) {
662+ T product (const UMTensor<T> &arg) {
684663 detail::to_device (arg);
685664 auto stream = device::stream_for (arg.range ());
686665 auto result =
@@ -690,7 +669,7 @@ typename UMTensor<T>::value_type product(const UMTensor<T> &arg) {
690669}
691670
692671template <typename T>
693- typename UMTensor<T>::value_type max (const UMTensor<T> &arg) {
672+ T max (const UMTensor<T> &arg) {
694673 detail::to_device (arg);
695674 auto stream = device::stream_for (arg.range ());
696675 auto result =
@@ -700,7 +679,7 @@ typename UMTensor<T>::value_type max(const UMTensor<T> &arg) {
700679}
701680
702681template <typename T>
703- typename UMTensor<T>::value_type min (const UMTensor<T> &arg) {
682+ T min (const UMTensor<T> &arg) {
704683 detail::to_device (arg);
705684 auto stream = device::stream_for (arg.range ());
706685 auto result =
@@ -710,7 +689,7 @@ typename UMTensor<T>::value_type min(const UMTensor<T> &arg) {
710689}
711690
712691template <typename T>
713- typename UMTensor<T>::value_type abs_max (const UMTensor<T> &arg) {
692+ T abs_max (const UMTensor<T> &arg) {
714693 detail::to_device (arg);
715694 auto stream = device::stream_for (arg.range ());
716695 auto result =
@@ -720,12 +699,13 @@ typename UMTensor<T>::value_type abs_max(const UMTensor<T> &arg) {
720699}
721700
722701template <typename T>
723- typename UMTensor<T>::value_type abs_min (const UMTensor<T> &arg) {
702+ T abs_min (const UMTensor<T> &arg) {
724703 detail::to_device (arg);
725704 auto stream = device::stream_for (arg.range ());
726705 auto result =
727706 device::absmin_kernel (detail::device_data (arg), arg.size (), stream);
728707 device::sync_madness_task_with (stream);
708+
729709 return result;
730710}
731711
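Taken together, these primitives compose into ordinary tile arithmetic. A closing sketch (construction of range is assumed):

// Sketch: composing the operations defined in this file.
UMTensor<double> a(range), b(range);  // 'range' assumed defined elsewhere
auto c = add(a, b);                   // c = a + b   (device copy + axpy)
scale_to(c, 0.5);                     // c *= 0.5    (device scal)
auto n2 = squared_norm(c);            // dot(c, c) on the device
auto nrm = norm(c);                   // std::sqrt(n2) on the host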