Skip to content

Commit 6a52cf1

Browse files
committed
Fixed bug in reading/writing PT files
1 parent 881c892 commit 6a52cf1

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

include/iganet.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2021,7 +2021,7 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
20212021
[&](auto &&...elems) {
20222022
std::size_t counter = 0;
20232023
(elems.write(archive,
2024-
key + ".output[" + std::to_string(counter++) + "]"),
2024+
key + ".collpts[" + std::to_string(counter++) + "]"),
20252025
...);
20262026
},
20272027
Base::collPts());
@@ -2067,13 +2067,18 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
20672067
[&](auto &&...elems) {
20682068
std::size_t counter = 0;
20692069
(elems.read(archive,
2070-
key + ".output[" + std::to_string(counter++) + "]"),
2070+
key + ".collpts[" + std::to_string(counter++) + "]"),
20712071
...);
20722072
},
20732073
Base::collPts());
20742074
}
20752075

20762076
net_->read(archive, key + ".net");
2077+
2078+
auto o = iganet::Options<typename Base::value_type>{};
2079+
for (auto& layer : net_->layers_)
2080+
layer->to(o.device(), o.dtype(), true);
2081+
20772082
torch::serialize::InputArchive archive_net;
20782083
archive.read(key + ".net.data", archive_net);
20792084
net_->load(archive_net);

0 commit comments

Comments
 (0)