Skip to content

Commit 720f678

Browse files
committed
UMTensor: refactor serialization implementation
1 parent 60d4abd commit 720f678

File tree

1 file changed

+53
-16
lines changed

1 file changed

+53
-16
lines changed

src/TiledArray/device/um_tensor.h

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -735,31 +735,68 @@ typename UMTensor<T>::value_type abs_min(const UMTensor<T> &arg) {
735735
namespace madness {
736736
namespace archive {
737737

738-
template <class Archive, typename T>
738+
template <typename Archive, typename T>
739+
struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
740+
static inline void store(const Archive &ar,
741+
const TiledArray::UMTensor<T> &t) {
742+
ar & t.range();
743+
ar & t.nbatch();
744+
if (t.range().volume() > 0) {
745+
auto stream = TiledArray::device::stream_for(t.range());
746+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(t,
747+
stream);
748+
ar &madness::archive::wrap(t.data(), t.range().volume() * t.nbatch());
749+
}
750+
}
751+
};
752+
753+
template <typename Archive, typename T>
739754
struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
740755
static inline void load(const Archive &ar, TiledArray::UMTensor<T> &t) {
741756
TiledArray::Range range{};
757+
size_t nbatch{};
742758
ar & range;
743-
759+
ar & nbatch;
744760
if (range.volume() > 0) {
745-
t = TiledArray::UMTensor<T>(std::move(range));
746-
ar &madness::archive::wrap(t.data(), t.size());
747-
} else {
748-
t = TiledArray::UMTensor<T>{};
761+
t = TiledArray::UMTensor<T>(std::move(range), nbatch);
762+
ar &madness::archive::wrap(t.data(), range.volume() * nbatch);
749763
}
750764
}
751765
};
752766

753-
template <class Archive, typename T>
754-
struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
755-
static inline void store(const Archive &ar,
756-
const TiledArray::UMTensor<T> &t) {
757-
ar & t.range();
758-
if (t.range().volume() > 0) {
759-
ar &madness::archive::wrap(t.data(), t.size());
760-
}
761-
}
762-
};
767+
// template <class Archive, typename T>
768+
// struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
769+
// static inline void load(const Archive &ar, TiledArray::UMTensor<T> &t) {
770+
// TiledArray::Range range{};
771+
// TiledArray::UMTensor<T> data;
772+
// ar & range & data;
773+
// t = TiledArray::UMTensor<T>(std::move(range), std::move(data));
774+
775+
// // if (range.volume() > 0) {
776+
// // t = TiledArray::UMTensor<T>(std::move(range));
777+
// // ar & madness::archive::wrap(t.data(), t.size());
778+
// // } else {
779+
// // t = TiledArray::UMTensor<T>{};
780+
// // }
781+
// }
782+
// };
783+
784+
// template <class Archive, typename T>
785+
// struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
786+
// static inline void store(const Archive &ar,
787+
// const TiledArray::UMTensor<T> &t) {
788+
// ar & t.range();
789+
// auto stream = TiledArray::device::stream_for(t.range());
790+
// TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
791+
// t, stream);
792+
793+
// ar & t.range() & t;
794+
795+
// // if (t.range().volume() > 0) {
796+
// // ar &madness::archive::wrap(t.data(), t.size());
797+
// // }
798+
// }
799+
// };
763800

764801
} // namespace archive
765802
} // namespace madness

0 commit comments

Comments
 (0)