@@ -19,9 +19,13 @@ const int global_iteration = 1;
1919#include " Libs/Optimize/Utils/MemoryUsage.h"
2020
2121#include < Profiling.h>
22+ #include " Libs/Optimize/Function/EarlyStop/EarlyStopping.h"
23+ #include < Logging.h>
2224
2325namespace shapeworks {
24- GradientDescentOptimizer::GradientDescentOptimizer () {
26+ GradientDescentOptimizer::GradientDescentOptimizer ()
27+ : m_EarlyStopping()
28+ {
2529 m_StopOptimization = false ;
2630 m_NumberOfIterations = 0 ;
2731 m_MaximumNumberOfIterations = 0 ;
@@ -48,6 +52,33 @@ void GradientDescentOptimizer::ResetTimeStepVectors() {
4852 }
4953}
5054
55+ void GradientDescentOptimizer::SetEarlyStoppingConfig (const EarlyStoppingConfig& config) {
56+ m_EarlyStopping.SetConfigParams (
57+ config.frequency ,
58+ config.window_size ,
59+ config.threshold ,
60+ config.strategy ,
61+ config.ema_alpha ,
62+ config.enable_logging ,
63+ config.logger_name ,
64+ config.warmup_iters
65+ );
66+ m_EarlyStoppingEnabled = config.enabled ;
67+ }
68+
69+ void GradientDescentOptimizer::InitializeEarlyStoppingScoreFunction (
70+ const ParticleSystemType* p) {
71+ bool early_stopping_status = m_EarlyStopping.SetControlShapes (p);
72+ if (early_stopping_status == false ) {
73+ SW_WARN (
74+ " Early stopping has been forcibly disabled. Possible causes: no fixed "
75+ " shapes/domains "
76+ " to fit PCA, or PCA fitting failed. Check logs for details." );
77+ }
78+ m_EarlyStoppingScoreFunctionReady = early_stopping_status;
79+ m_EarlyStoppingEnabled = early_stopping_status;
80+ }
81+
5182void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization () {
5283 TIME_SCOPE (" GradientDescentOptimizer" );
5384 // / uncomment this to run single threaded
@@ -78,6 +109,7 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
78109 unsigned int counter = 0 ;
79110
80111 double maxchange = 0.0 ;
112+ if (m_EarlyStoppingEnabled) m_EarlyStopping.reset (); // reset early stopping cache before starting optimization
81113 while (m_StopOptimization == false ) // iterations loop
82114 {
83115 TIME_SCOPE (" optimizer_iteration" );
@@ -93,6 +125,16 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
93125
94126 const auto accTimerBegin = std::chrono::steady_clock::now ();
95127 m_GradientFunction->SetParticleSystem (m_ParticleSystem);
128+ if (m_EarlyStoppingEnabled && !m_EarlyStoppingScoreFunctionReady) {
129+ bool early_stopping_status = m_EarlyStopping.SetControlShapes (m_ParticleSystem);
130+ if (early_stopping_status == false ) {
131+ SW_WARN (
132+ " Early stopping has been forcibly disabled. Possible causes: no fixed shapes/domains "
133+ " to fit PCA, or PCA fitting failed. Check logs." );
134+ }
135+ m_EarlyStoppingEnabled = early_stopping_status; // forcibly turn off early stopping if no fixed domains present
136+ m_EarlyStoppingScoreFunctionReady = early_stopping_status;
137+ }
96138
97139 TIME_START (" gradient_before_iteration" );
98140 if (counter % global_iteration == 0 ) m_GradientFunction->BeforeIteration ();
@@ -234,6 +276,17 @@ void GradientDescentOptimizer::StartAdaptiveGaussSeidelOptimization() {
234276 m_StopOptimization = true ;
235277 }
236278
279+ if (m_EarlyStoppingEnabled) {
280+ m_EarlyStopping.update (m_NumberOfIterations, m_ParticleSystem);
281+ if (m_EarlyStopping.ShouldStop ()) {
282+ std::cerr << " Early stopping triggered at optimization iteration "
283+ << m_NumberOfIterations << std::endl;
284+ SW_LOG (" Early stopping triggered at optimization iteration {}" ,
285+ m_NumberOfIterations);
286+ m_StopOptimization = true ;
287+ }
288+ }
289+
237290 } // end while stop optimization
238291}
239292
0 commit comments