Skip to content

Commit bbe8c2f

Browse files
integrate early stopping logic in GaussSiedel optimization func
1 parent 28eb004 commit bbe8c2f

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

Libs/Optimize/GradientDescentOptimizer.cpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@ const int global_iteration = 1;
1919
#include "Libs/Optimize/Utils/MemoryUsage.h"
2020

2121
#include <Profiling.h>
22+
#include "Libs/Optimize/Function/EarlyStop/EarlyStopping.h"
23+
#include <Logging.h>
2224

2325
namespace shapeworks {
24-
GradientDescentOptimizer::GradientDescentOptimizer() {
26+
GradientDescentOptimizer::GradientDescentOptimizer()
27+
: m_EarlyStopping()
28+
{
2529
m_StopOptimization = false;
2630
m_NumberOfIterations = 0;
2731
m_MaximumNumberOfIterations = 0;
@@ -48,6 +52,33 @@ void GradientDescentOptimizer::ResetTimeStepVectors() {
4852
}
4953
}
5054

55+
void GradientDescentOptimizer::SetEarlyStoppingConfig(const EarlyStoppingConfig& config) {
56+
m_EarlyStopping.SetConfigParams(
57+
config.frequency,
58+
config.window_size,
59+
config.threshold,
60+
config.strategy,
61+
config.ema_alpha,
62+
config.enable_logging,
63+
config.logger_name,
64+
config.warmup_iters
65+
);
66+
m_EarlyStoppingEnabled = config.enabled;
67+
}
68+
69+
void GradientDescentOptimizer::InitializeEarlyStoppingScoreFunction(
70+
const ParticleSystemType* p) {
71+
bool early_stopping_status = m_EarlyStopping.SetControlShapes(p);
72+
if (early_stopping_status == false) {
73+
SW_WARN(
74+
"Early stopping has been forcibly disabled. Possible causes: no fixed "
75+
"shapes/domains "
76+
"to fit PCA, or PCA fitting failed. Check logs for details.");
77+
}
78+
m_EarlyStoppingScoreFunctionReady = early_stopping_status;
79+
m_EarlyStoppingEnabled = early_stopping_status;
80+
}
81+
5182
void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
5283
TIME_SCOPE("GradientDescentOptimizer");
5384
/// uncomment this to run single threaded
@@ -78,6 +109,7 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
78109
unsigned int counter = 0;
79110

80111
double maxchange = 0.0;
112+
if (m_EarlyStoppingEnabled) m_EarlyStopping.reset(); // reset early stopping cache before starting optimization
81113
while (m_StopOptimization == false) // iterations loop
82114
{
83115
TIME_SCOPE("optimizer_iteration");
@@ -93,6 +125,16 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
93125

94126
const auto accTimerBegin = std::chrono::steady_clock::now();
95127
m_GradientFunction->SetParticleSystem(m_ParticleSystem);
128+
if (m_EarlyStoppingEnabled && !m_EarlyStoppingScoreFunctionReady) {
129+
bool early_stopping_status = m_EarlyStopping.SetControlShapes(m_ParticleSystem);
130+
if (early_stopping_status == false) {
131+
SW_WARN(
132+
"Early stopping has been forcibly disabled. Possible causes: no fixed shapes/domains "
133+
"to fit PCA, or PCA fitting failed. Check logs.");
134+
}
135+
m_EarlyStoppingEnabled = early_stopping_status; // forcibly turn off early stopping if no fixed domains present
136+
m_EarlyStoppingScoreFunctionReady = early_stopping_status;
137+
}
96138

97139
TIME_START("gradient_before_iteration");
98140
if (counter % global_iteration == 0) m_GradientFunction->BeforeIteration();
@@ -234,6 +276,17 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
234276
m_StopOptimization = true;
235277
}
236278

279+
if (m_EarlyStoppingEnabled) {
280+
m_EarlyStopping.update(m_NumberOfIterations, m_ParticleSystem);
281+
if (m_EarlyStopping.ShouldStop()) {
282+
std::cerr << "Early stopping triggered at optimization iteration "
283+
<< m_NumberOfIterations << std::endl;
284+
SW_LOG("Early stopping triggered at optimization iteration {}",
285+
m_NumberOfIterations);
286+
m_StopOptimization = true;
287+
}
288+
}
289+
237290
} // end while stop optimization
238291
}
239292

Libs/Optimize/GradientDescentOptimizer.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "Libs/Optimize/Domain/ImageDomainWithGradients.h"
88
#include "Libs/Optimize/Function/VectorFunction.h"
9+
#include "Libs/Optimize/Function/EarlyStop/EarlyStopping.h"
10+
#include "EarlyStoppingConfig.h"
911
#include "ParticleSystem.h"
1012
#include "itkObject.h"
1113
#include "itkObjectFactory.h"
@@ -63,6 +65,8 @@ class GradientDescentOptimizer : public itk::Object {
6365
/** Start the optimization. */
6466
void StartOptimization() { this->StartAdaptiveGaussSeidelOptimization(); }
6567
void StartAdaptiveGaussSeidelOptimization();
68+
void SetEarlyStoppingConfig(const EarlyStoppingConfig& config);
69+
void InitializeEarlyStoppingScoreFunction(const ParticleSystemType* p);
6670

6771
void AugmentedLagrangianConstraints(VectorType& gradient, const PointType& pt, const size_t& dom,
6872
const double& maximumUpdateAllowed, size_t index);
@@ -132,6 +136,9 @@ class GradientDescentOptimizer : public itk::Object {
132136
double m_TimeStep;
133137
std::vector<std::vector<double> > m_TimeSteps;
134138
unsigned int m_verbosity;
139+
EarlyStopping m_EarlyStopping;
140+
bool m_EarlyStoppingEnabled = false;
141+
bool m_EarlyStoppingScoreFunctionReady = false;
135142

136143
// Adaptive Initialization variables
137144
bool m_initialization_mode = false;

0 commit comments

Comments
 (0)