@@ -1218,10 +1218,12 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12181218 loss.isnan ().item <bool >()) {
12191219 Log (log::info) << " Total epochs: " << epoch
12201220 << " , loss: " << current_loss << std::endl;
1221- break ;
1221+ return ;
12221222 }
12231223 previous_loss = current_loss;
12241224 }
1225+ Log (log::info) << " Max epochs reached: " << options_.max_epoch ()
1226+ << " , loss: " << previous_loss << std::endl;
12251227 }
12261228
12271229 // / @brief Trains the IgANet
@@ -1239,7 +1241,7 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12391241 // Loop over epochs
12401242 for (int64_t epoch = 0 ; epoch != options_.max_epoch (); ++epoch) {
12411243
1242- typename Base::value_type Loss (0 );
1244+ typename Base::value_type current_loss (0 );
12431245
12441246 for (auto &batch : loader) {
12451247 inputs = batch.data ;
@@ -1298,29 +1300,24 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12981300 // Update the parameters based on the calculated gradients
12991301 opt_->step (closure);
13001302
1301- Loss += loss.item <typename Base::value_type>();
1302- }
1303-
1304- Log (log::verbose) << " Epoch " << std::to_string (epoch) << " : " << Loss
1305- << std::endl;
1306-
1307- if (Loss < options_.min_loss ()) {
1308- Log (log::info) << " Total epochs: " << epoch << " , loss: " << Loss
1309- << std::endl;
1310- break ;
1303+ current_loss += loss.item <typename Base::value_type>();
13111304 }
1305+ Log (log::verbose) << " Epoch " << std::to_string (epoch) << " : "
1306+ << current_loss << std::endl;
13121307
1313- if (Loss == previous_loss) {
1314- Log (log::info) << " Total epochs: " << epoch << " , loss: " << Loss
1315- << std::endl;
1316- break ;
1308+ if (current_loss < options_.min_loss () ||
1309+ std::abs (current_loss - previous_loss) < options_.min_loss_change () ||
1310+ std::abs (current_loss - previous_loss) / current_loss <
1311+ options_.min_loss_rel_change () ||
1312+ loss.isnan ().item <bool >()) {
1313+ Log (log::info) << " Total epochs: " << epoch
1314+ << " , loss: " << current_loss << std::endl;
1315+ return ;
13171316 }
1318- previous_loss = Loss;
1319-
1320- if (epoch == options_.max_epoch () - 1 )
1321- Log (log::warning) << " Total epochs: " << epoch << " , loss: " << Loss
1322- << std::endl;
1317+ previous_loss = current_loss;
13231318 }
1319+ Log (log::info) << " Max epochs reached: " << options_.max_epoch ()
1320+ << " , loss: " << previous_loss << std::endl;
13241321 }
13251322
13261323 // / @brief Evaluate IgANet
@@ -1825,10 +1822,12 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
18251822 loss.isnan ().item <bool >()) {
18261823 Log (log::info) << " Total epochs: " << epoch
18271824 << " , loss: " << current_loss << std::endl;
1828- break ;
1825+ return ;
18291826 }
18301827 previous_loss = current_loss;
18311828 }
1829+ Log (log::info) << " Max epochs reached: " << options_.max_epoch ()
1830+ << " , loss: " << previous_loss << std::endl;
18321831 }
18331832
18341833 // / @brief Trains the IgANet
@@ -1846,7 +1845,7 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
18461845 // Loop over epochs
18471846 for (int64_t epoch = 0 ; epoch != options_.max_epoch (); ++epoch) {
18481847
1849- typename Base::value_type Loss (0 );
1848+ typename Base::value_type current_loss (0 );
18501849
18511850 for (auto &batch : loader) {
18521851 inputs = batch.data ;
@@ -1905,29 +1904,24 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
19051904 // Update the parameters based on the calculated gradients
19061905 opt_->step (closure);
19071906
1908- Loss += loss.item <typename Base::value_type>();
1909- }
1910-
1911- Log (log::verbose) << " Epoch " << std::to_string (epoch) << " : " << Loss
1912- << std::endl;
1913-
1914- if (Loss < options_.min_loss ()) {
1915- Log (log::info) << " Total epochs: " << epoch << " , loss: " << Loss
1916- << std::endl;
1917- break ;
1907+ current_loss += loss.item <typename Base::value_type>();
19181908 }
1909+ Log (log::verbose) << " Epoch " << std::to_string (epoch) << " : "
1910+ << current_loss << std::endl;
19191911
1920- if (Loss == previous_loss) {
1921- Log (log::info) << " Total epochs: " << epoch << " , loss: " << Loss
1922- << std::endl;
1923- break ;
1912+ if (current_loss < options_.min_loss () ||
1913+ std::abs (current_loss - previous_loss) < options_.min_loss_change () ||
1914+ std::abs (current_loss - previous_loss) / current_loss <
1915+ options_.min_loss_rel_change () ||
1916+ loss.isnan ().item <bool >()) {
1917+ Log (log::info) << " Total epochs: " << epoch
1918+ << " , loss: " << current_loss << std::endl;
1919+ return ;
19241920 }
1925- previous_loss = Loss;
1926-
1927- if (epoch == options_.max_epoch () - 1 )
1928- Log (log::warning) << " Total epochs: " << epoch << " , loss: " << Loss
1929- << std::endl;
1921+ previous_loss = current_loss;
19301922 }
1923+ Log (log::info) << " Max epochs reached: " << options_.max_epoch ()
1924+ << " , loss: " << previous_loss << std::endl;
19311925 }
19321926
19331927 // / @brief Evaluate IgANet
0 commit comments