Skip to content

Commit 747aa22

Browse files
committed
Fixed small bug in reading B-splines from PT files
1 parent 6a52cf1 commit 747aa22

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

include/iganet.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,8 @@ class IgANetGeneratorImpl : public torch::nn::Module {
764764
const std::string &key = "iganet") {
765765
torch::Tensor layers, in_features, outputs_features, bias, activation;
766766

767+
auto options = iganet::Options<real_t>{};
768+
767769
archive.read(key + ".layers", layers);
768770
for (int64_t i = 0; i < layers.item<int64_t>(); ++i) {
769771
archive.read(key + ".layer[" + std::to_string(i) + "].in_features",
@@ -776,8 +778,9 @@ class IgANetGeneratorImpl : public torch::nn::Module {
776778
torch::nn::Linear(
777779
torch::nn::LinearOptions(in_features.item<int64_t>(),
778780
outputs_features.item<int64_t>())
779-
.bias(bias.item<bool>()))));
780-
781+
.bias(bias.item<bool>()))));
782+
layers_.back()->to(options.device(), options.dtype(), true);
783+
781784
archive.read(key + ".layer[" + std::to_string(i) + "].activation.type",
782785
activation);
783786
switch (static_cast<enum activation>(activation.item<int64_t>())) {
@@ -2073,12 +2076,7 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
20732076
Base::collPts());
20742077
}
20752078

2076-
net_->read(archive, key + ".net");
2077-
2078-
auto o = iganet::Options<typename Base::value_type>{};
2079-
for (auto& layer : net_->layers_)
2080-
layer->to(o.device(), o.dtype(), true);
2081-
2079+
net_->read(archive, key + ".net");
20822080
torch::serialize::InputArchive archive_net;
20832081
archive.read(key + ".net.data", archive_net);
20842082
net_->load(archive_net);

0 commit comments

Comments
 (0)