@@ -735,31 +735,68 @@ typename UMTensor<T>::value_type abs_min(const UMTensor<T> &arg) {
735735namespace madness {
736736namespace archive {
737737
738- template <class Archive , typename T>
738+ template <typename Archive, typename T>
739+ struct ArchiveStoreImpl <Archive, TiledArray::UMTensor<T>> {
740+ static inline void store (const Archive &ar,
741+ const TiledArray::UMTensor<T> &t) {
742+ ar & t.range ();
743+ ar & t.nbatch ();
744+ if (t.range ().volume () > 0 ) {
745+ auto stream = TiledArray::device::stream_for (t.range ());
746+ TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(t,
747+ stream);
748+ ar &madness::archive::wrap (t.data (), t.range ().volume () * t.nbatch ());
749+ }
750+ }
751+ };
752+
753+ template <typename Archive, typename T>
739754struct ArchiveLoadImpl <Archive, TiledArray::UMTensor<T>> {
740755 static inline void load (const Archive &ar, TiledArray::UMTensor<T> &t) {
741756 TiledArray::Range range{};
757+ size_t nbatch{};
742758 ar & range;
743-
759+ ar & nbatch;
744760 if (range.volume () > 0 ) {
745- t = TiledArray::UMTensor<T>(std::move (range));
746- ar &madness::archive::wrap (t.data (), t.size ());
747- } else {
748- t = TiledArray::UMTensor<T>{};
761+ t = TiledArray::UMTensor<T>(std::move (range), nbatch);
762+ ar &madness::archive::wrap (t.data (), range.volume () * nbatch);
749763 }
750764 }
751765};
752766
753- template <class Archive , typename T>
754- struct ArchiveStoreImpl <Archive, TiledArray::UMTensor<T>> {
755- static inline void store (const Archive &ar,
756- const TiledArray::UMTensor<T> &t) {
757- ar & t.range ();
758- if (t.range ().volume () > 0 ) {
759- ar &madness::archive::wrap (t.data (), t.size ());
760- }
761- }
762- };
767+ // template <class Archive, typename T>
768+ // struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
769+ // static inline void load(const Archive &ar, TiledArray::UMTensor<T> &t) {
770+ // TiledArray::Range range{};
771+ // TiledArray::UMTensor<T> data;
772+ // ar & range & data;
773+ // t = TiledArray::UMTensor<T>(std::move(range), std::move(data));
774+
775+ // // if (range.volume() > 0) {
776+ // // t = TiledArray::UMTensor<T>(std::move(range));
777+ // // ar & madness::archive::wrap(t.data(), t.size());
778+ // // } else {
779+ // // t = TiledArray::UMTensor<T>{};
780+ // // }
781+ // }
782+ // };
783+
784+ // template <class Archive, typename T>
785+ // struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
786+ // static inline void store(const Archive &ar,
787+ // const TiledArray::UMTensor<T> &t) {
788+ // ar & t.range();
789+ // auto stream = TiledArray::device::stream_for(t.range());
790+ // TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
791+ // t, stream);
792+
793+ // ar & t.range() & t;
794+
795+ // // if (t.range().volume() > 0) {
796+ // // ar &madness::archive::wrap(t.data(), t.size());
797+ // // }
798+ // }
799+ // };
763800
764801} // namespace archive
765802} // namespace madness
0 commit comments