@@ -764,6 +764,8 @@ class IgANetGeneratorImpl : public torch::nn::Module {
764764 const std::string &key = " iganet" ) {
765765 torch::Tensor layers, in_features, outputs_features, bias, activation;
766766
767+ auto options = iganet::Options<real_t >{};
768+
767769 archive.read (key + " .layers" , layers);
768770 for (int64_t i = 0 ; i < layers.item <int64_t >(); ++i) {
769771 archive.read (key + " .layer[" + std::to_string (i) + " ].in_features" ,
@@ -776,8 +778,9 @@ class IgANetGeneratorImpl : public torch::nn::Module {
776778 torch::nn::Linear (
777779 torch::nn::LinearOptions (in_features.item <int64_t >(),
778780 outputs_features.item <int64_t >())
779- .bias (bias.item <bool >()))));
780-
781+ .bias (bias.item <bool >()))));
782+ layers_.back ()->to (options.device (), options.dtype (), true );
783+
781784 archive.read (key + " .layer[" + std::to_string (i) + " ].activation.type" ,
782785 activation);
783786 switch (static_cast <enum activation>(activation.item <int64_t >())) {
@@ -2073,12 +2076,7 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
20732076 Base::collPts ());
20742077 }
20752078
2076- net_->read (archive, key + " .net" );
2077-
2078- auto o = iganet::Options<typename Base::value_type>{};
2079- for (auto & layer : net_->layers_ )
2080- layer->to (o.device (), o.dtype (), true );
2081-
2079+ net_->read (archive, key + " .net" );
20822080 torch::serialize::InputArchive archive_net;
20832081 archive.read (key + " .net.data" , archive_net);
20842082 net_->load (archive_net);
0 commit comments