Skip to content

Commit 739d264

Browse files
committed
UM: Add to_host functions for Tensor types. Also remove the non-const version of to_device, since pre-fetching is not a mutating operation.
1 parent e6c3ba5 commit 739d264

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

src/TiledArray/device/btas_um_tensor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ void to_device(const TiledArray::btasUMTensorVarray<T> &tile) {
5454
tile.storage(), stream);
5555
}
5656

57+
/// Pre-fetch a tile's memory to the host.
///
/// Obtains the stream associated with the tile's range via
/// device::stream_for() and hands the tile's storage, together with that
/// stream, to TiledArray::to_execution_space<ExecutionSpace::Host>().
/// The tile is taken by const reference.
///
/// \tparam T template argument of TiledArray::btasUMTensorVarray
/// \param tile the tile whose storage is pre-fetched to the host
58+
template <typename T>
59+
void to_host(const TiledArray::btasUMTensorVarray<T> &tile) {
60+
auto stream = device::stream_for(tile.range());
61+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
62+
tile.storage(), stream);
63+
}
64+
5765
} // end of namespace detail
5866

5967
} // end of namespace TiledArray

src/TiledArray/device/um_tensor.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ void to_device(const UMTensor<T> &tensor) {
6161
stream);
6262
}
6363

64-
/// pre-fetch to device (non-const)
64+
/// Pre-fetch a tensor to the host.
///
/// Obtains the stream associated with the tensor's range via
/// device::stream_for() and passes the tensor, together with that stream,
/// to TiledArray::to_execution_space<ExecutionSpace::Host>().
/// The tensor is taken by const reference.
///
/// \tparam T template argument of UMTensor
/// \param tensor the tensor to pre-fetch to the host
6565
template <typename T>
66-
void to_device(UMTensor<T> &tensor) {
66+
void to_host(const UMTensor<T> &tensor) {
6767
auto stream = device::stream_for(tensor.range());
68-
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
69-
stream);
68+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(tensor,
69+
stream);
7070
}
7171

7272
/// get device data pointer
@@ -643,8 +643,8 @@ T squared_norm(const UMTensor<T> &arg) {
643643

644644
// compute squared norm using dot
645645
auto result = T(0);
646-
blas::dot(arg.size(), detail::device_data(arg), 1,
647-
detail::device_data(arg), 1, &result, queue);
646+
blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
647+
1, &result, queue);
648648
device::sync_madness_task_with(stream);
649649
return result;
650650
}

0 commit comments

Comments
 (0)