Skip to content

Commit 8ba0a1f

Browse files
committed
Fixed bug in log output
1 parent 2cb6d2d commit 8ba0a1f

File tree

1 file changed

+36
-42
lines changed

1 file changed

+36
-42
lines changed

include/iganet.hpp

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,10 +1218,12 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12181218
loss.isnan().item<bool>()) {
12191219
Log(log::info) << "Total epochs: " << epoch
12201220
<< ", loss: " << current_loss << std::endl;
1221-
break;
1221+
return;
12221222
}
12231223
previous_loss = current_loss;
12241224
}
1225+
Log(log::info) << "Max epochs reached: " << options_.max_epoch()
1226+
<< ", loss: " << previous_loss << std::endl;
12251227
}
12261228

12271229
/// @brief Trains the IgANet
@@ -1239,7 +1241,7 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12391241
// Loop over epochs
12401242
for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {
12411243

1242-
typename Base::value_type Loss(0);
1244+
typename Base::value_type current_loss(0);
12431245

12441246
for (auto &batch : loader) {
12451247
inputs = batch.data;
@@ -1298,29 +1300,24 @@ class IgANet : public IgABase<GeometryMap, Variable>,
12981300
// Update the parameters based on the calculated gradients
12991301
opt_->step(closure);
13001302

1301-
Loss += loss.item<typename Base::value_type>();
1302-
}
1303-
1304-
Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": " << Loss
1305-
<< std::endl;
1306-
1307-
if (Loss < options_.min_loss()) {
1308-
Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
1309-
<< std::endl;
1310-
break;
1303+
current_loss += loss.item<typename Base::value_type>();
13111304
}
1305+
Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
1306+
<< current_loss << std::endl;
13121307

1313-
if (Loss == previous_loss) {
1314-
Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
1315-
<< std::endl;
1316-
break;
1308+
if (current_loss < options_.min_loss() ||
1309+
std::abs(current_loss - previous_loss) < options_.min_loss_change() ||
1310+
std::abs(current_loss - previous_loss) / current_loss <
1311+
options_.min_loss_rel_change() ||
1312+
loss.isnan().item<bool>()) {
1313+
Log(log::info) << "Total epochs: " << epoch
1314+
<< ", loss: " << current_loss << std::endl;
1315+
return;
13171316
}
1318-
previous_loss = Loss;
1319-
1320-
if (epoch == options_.max_epoch() - 1)
1321-
Log(log::warning) << "Total epochs: " << epoch << ", loss: " << Loss
1322-
<< std::endl;
1317+
previous_loss = current_loss;
13231318
}
1319+
Log(log::info) << "Max epochs reached: " << options_.max_epoch()
1320+
<< ", loss: " << previous_loss << std::endl;
13241321
}
13251322

13261323
/// @brief Evaluate IgANet
@@ -1825,10 +1822,12 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
18251822
loss.isnan().item<bool>()) {
18261823
Log(log::info) << "Total epochs: " << epoch
18271824
<< ", loss: " << current_loss << std::endl;
1828-
break;
1825+
return;
18291826
}
18301827
previous_loss = current_loss;
18311828
}
1829+
Log(log::info) << "Max epochs reached: " << options_.max_epoch()
1830+
<< ", loss: " << previous_loss << std::endl;
18321831
}
18331832

18341833
/// @brief Trains the IgANet
@@ -1846,7 +1845,7 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
18461845
// Loop over epochs
18471846
for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {
18481847

1849-
typename Base::value_type Loss(0);
1848+
typename Base::value_type current_loss(0);
18501849

18511850
for (auto &batch : loader) {
18521851
inputs = batch.data;
@@ -1905,29 +1904,24 @@ class IgANet2 : public IgABase2<Inputs, Outputs, CollPts>,
19051904
// Update the parameters based on the calculated gradients
19061905
opt_->step(closure);
19071906

1908-
Loss += loss.item<typename Base::value_type>();
1909-
}
1910-
1911-
Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": " << Loss
1912-
<< std::endl;
1913-
1914-
if (Loss < options_.min_loss()) {
1915-
Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
1916-
<< std::endl;
1917-
break;
1907+
current_loss += loss.item<typename Base::value_type>();
19181908
}
1909+
Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
1910+
<< current_loss << std::endl;
19191911

1920-
if (Loss == previous_loss) {
1921-
Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
1922-
<< std::endl;
1923-
break;
1912+
if (current_loss < options_.min_loss() ||
1913+
std::abs(current_loss - previous_loss) < options_.min_loss_change() ||
1914+
std::abs(current_loss - previous_loss) / current_loss <
1915+
options_.min_loss_rel_change() ||
1916+
loss.isnan().item<bool>()) {
1917+
Log(log::info) << "Total epochs: " << epoch
1918+
<< ", loss: " << current_loss << std::endl;
1919+
return;
19241920
}
1925-
previous_loss = Loss;
1926-
1927-
if (epoch == options_.max_epoch() - 1)
1928-
Log(log::warning) << "Total epochs: " << epoch << ", loss: " << Loss
1929-
<< std::endl;
1921+
previous_loss = current_loss;
19301922
}
1923+
Log(log::info) << "Max epochs reached: " << options_.max_epoch()
1924+
<< ", loss: " << previous_loss << std::endl;
19311925
}
19321926

19331927
/// @brief Evaluate IgANet

0 commit comments

Comments
 (0)