1+ #include " EarlyStopping.h"
2+ #include " MorphologicalDeviationScore.h"
3+ #include < Eigen/Dense>
4+ #include < Profiling.h>
5+ #include < Logging.h>
6+
7+ namespace shapeworks {
8+
9+ EarlyStopping::EarlyStopping ()
10+ : frequency_(100 ),
11+ window_size_ (5 ),
12+ threshold_(0.0001 ),
13+ strategy_(EarlyStoppingStrategy::RelativeDifference),
14+ ema_alpha_(0.2 ),
15+ enable_logging_(false ),
16+ logger_name_(" early_stopping" ),
17+ last_checked_iter_(-1 ),
18+ warmup_iters_(1000 ),
19+ score_func_() {}
20+
21+ // ---------------------------------------------------------------------------
22+ void EarlyStopping::SetConfigParams (int frequency, int window_size,
23+ double threshold,
24+ EarlyStoppingStrategy strategy,
25+ double ema_alpha, bool enable_logging,
26+ const std::string& logger_name,
27+ int warmup_iters) {
28+ frequency_ = frequency;
29+ window_size_ = window_size;
30+ threshold_ = threshold;
31+ strategy_ = strategy;
32+ ema_alpha_ = ema_alpha;
33+ enable_logging_ = enable_logging;
34+ logger_name_ = logger_name;
35+ last_checked_iter_ = -1 ;
36+ warmup_iters_ = warmup_iters;
37+ }
38+
39+ // ---------------------------------------------------------------------------
40+ void EarlyStopping::reset () {
41+ score_history_.clear ();
42+ ema_initialized_ = false ;
43+ ema_diff_ = Eigen::VectorXd ();
44+ last_checked_iter_ = -1 ;
45+ // stop_flag_ = false;
46+ stop_flag_.store (false );
47+ }
48+
49+ // ---------------------------------------------------------------------------
50+ void EarlyStopping::update (int iteration, const ParticleSystem* p) {
51+ if (stop_flag_ || (iteration - last_checked_iter_ < frequency_)) return ;
52+
53+ if (iteration < warmup_iters_) return ;
54+
55+ if (control_shapes_.size () == 0 ) {
56+ throw std::runtime_error (" Control shapes not initialized yet." );
57+ }
58+
59+ TIME_START (" early_stopping_score" );
60+ Eigen::MatrixXd X = GetTestShapes (p);
61+ if (X.size () == 0 ) return ;
62+
63+ last_checked_iter_ = iteration;
64+ Eigen::VectorXd score = ComputeScore (X);
65+
66+ SW_DEBUG (" Early stopping score mean at optimization iteration {} = {}" , iteration, score.mean ());
67+
68+ if (!ema_initialized_) {
69+ ema_diff_ = Eigen::VectorXd::Zero (score.size ());
70+ }
71+
72+ score_history_.push_back (score);
73+ if (score_history_.size () > window_size_ + 1 ) score_history_.pop_front ();
74+
75+ if (score_history_.size () >= 2 ) {
76+ if (HasConverged ())
77+ // stop_flag_ = true;
78+ stop_flag_.store (true );
79+ }
80+ TIME_STOP (" early_stopping_score" );
81+
82+ }
83+
84+ // bool EarlyStopping::ShouldStop() const {
85+ // return stop_flag_;
86+ // }
87+
88+ // ---------------------------------------------------------------------------
89+ bool EarlyStopping::ShouldStop () const { return stop_flag_.load (); }
90+
91+ // ---------------------------------------------------------------------------
92+ Eigen::VectorXd EarlyStopping::ComputeScore (const Eigen::MatrixXd& X) {
93+ return score_func_.GetMahalanobisDistance (X);
94+ }
95+
96+ // ---------------------------------------------------------------------------
97+ Eigen::VectorXd EarlyStopping::ComputeRelativeDiff (
98+ const Eigen::VectorXd& a, const Eigen::VectorXd& b) const {
99+ const double eps = 1e-8 ;
100+ return (a - b).cwiseAbs ().array () / (b.cwiseAbs ().array () + eps);
101+ }
102+
103+ // ---------------------------------------------------------------------------
104+ bool EarlyStopping::HasConverged () const {
105+
106+ switch (strategy_) {
107+ case EarlyStoppingStrategy::RelativeDifference:
108+ return CheckRelativeDifference ();
109+ case EarlyStoppingStrategy::ExponentialMovingAverage:
110+ return CheckExponentialMovingAverage ();
111+ default :
112+ return false ;
113+ }
114+ }
115+
116+ // ---------------------------------------------------------------------------
117+ bool EarlyStopping::CheckRelativeDifference () const {
118+
119+ const Eigen::VectorXd& current = score_history_.back ();
120+ const Eigen::VectorXd& prev = score_history_[score_history_.size () - 2 ];
121+
122+ Eigen::VectorXd rel_diff = ComputeRelativeDiff (current, prev);
123+ std::vector<bool > converged (rel_diff.size ());
124+ for (int i = 0 ; i < rel_diff.size (); ++i) {
125+ converged[i] = rel_diff[i] < threshold_;
126+ // std::cout << "DEBUG | Test Shape " << i << "/" << rel_diff.size() << " rel_diff = " << rel_diff[i] << " converged = " << converged[i];
127+ }
128+
129+ if (enable_logging_)
130+ LogStatus (last_checked_iter_, current, rel_diff, converged);
131+
132+ return std::all_of (converged.begin (), converged.end (),
133+ [](bool x) { return x; });
134+ }
135+
136+ // ---------------------------------------------------------------------------
137+ bool EarlyStopping::CheckExponentialMovingAverage () const {
138+ const Eigen::VectorXd& current = score_history_.back ();
139+ const Eigen::VectorXd& prev = score_history_[score_history_.size () - 2 ];
140+
141+ Eigen::VectorXd abs_diff = (current - prev).cwiseAbs ();
142+
143+ if (!ema_initialized_) {
144+ ema_diff_ = abs_diff;
145+ ema_initialized_ = true ;
146+ } else {
147+ ema_diff_ = ema_alpha_ * abs_diff + (1.0 - ema_alpha_) * ema_diff_;
148+ }
149+
150+ std::vector<bool > converged (ema_diff_.size ());
151+ for (int i = 0 ; i < ema_diff_.size (); ++i)
152+ {
153+ converged[i] = ema_diff_[i] < threshold_;
154+ // std::cout << "DEBUG | Test Shape " << i << "/" << ema_diff_.size() << " ema_diff_ = " << ema_diff_[i] << " converged = " << converged[i];
155+
156+ }
157+
158+ if (enable_logging_)
159+ LogStatus (last_checked_iter_, current, ema_diff_, converged);
160+
161+ return std::all_of (converged.begin (), converged.end (),
162+ [](bool x) { return x; });
163+ }
164+
165+ // ---------------------------------------------------------------------------
166+ void EarlyStopping::LogStatus (
167+ int iter, const Eigen::VectorXd& score, const Eigen::VectorXd& diff,
168+ const std::vector<bool >& per_subject_convergence) const {
169+ std::cout << " [EarlyStopping" ;
170+ if (!logger_name_.empty ()) std::cout << " - " << logger_name_;
171+ std::cout << " ] Iteration " << iter << " :\n " ;
172+
173+ for (int i = 0 ; i < score.size (); ++i) {
174+ std::cout << " TestShape[" << i << " ]: Score=" << score[i]
175+ << " , Diff=" << diff[i]
176+ << " , Status= " << (per_subject_convergence[i] ? " Converged" : " Not Converged" )
177+ << " \n " ;
178+ }
179+
180+ int num_converged = std::count (per_subject_convergence.begin (),
181+ per_subject_convergence.end (), true );
182+ std::cout << " --> " << num_converged << " / " << score.size ()
183+ << " shapes converged\n\n " ;
184+ }
185+
186+ // ---------------------------------------------------------------------------
187+ bool EarlyStopping::SetControlShapes (const ParticleSystem* p) {
188+ if (control_shapes_.size () > 0 ) return true ; // already initialized return back
189+
190+ unsigned int numdomains = p->GetNumberOfDomains ();
191+ const auto domains_per_shape = p->GetDomainsPerShape ();
192+ unsigned int num_shapes = numdomains / domains_per_shape;
193+
194+ std::vector<Eigen::RowVectorXd> shape_vectors;
195+ for (size_t shape = 0 ; shape < num_shapes; ++shape) {
196+ for (int shape_dom_idx = 0 ; shape_dom_idx < domains_per_shape;
197+ ++shape_dom_idx) {
198+ auto dom = shape * domains_per_shape + shape_dom_idx;
199+
200+ if (p->GetDomainFlag (dom) == true ) {
201+ auto num_points = p->GetPositions (dom)->GetSize ();
202+ Eigen::RowVectorXd shape_vector (VDimension * num_points);
203+
204+ for (int k = 0 ; k < num_points; ++k) {
205+ // PointType pt = p->GetTransformedPosition(dom, k);
206+ PointType pt = p->GetPositions (dom)->Get (k);
207+
208+ shape_vector (VDimension * k + 0 ) = pt[0 ];
209+ shape_vector (VDimension * k + 1 ) = pt[1 ];
210+ shape_vector (VDimension * k + 2 ) = pt[2 ];
211+ }
212+
213+ shape_vectors.push_back (shape_vector);
214+ }
215+ }
216+ }
217+
218+ int num_control_shapes = static_cast <int >(shape_vectors.size ());
219+ if (num_control_shapes > 0 ) {
220+ int d = static_cast <int >(shape_vectors[0 ].cols ());
221+ control_shapes_.resize (num_control_shapes, d);
222+ for (int i = 0 ; i < num_control_shapes; ++i) {
223+ control_shapes_.row (i) = shape_vectors[i];
224+ }
225+ } else {
226+ return false ; // Forcibly turn off early stopping now if no fixed domains present
227+ // TODO: handle non-fixed domain cases when no fixed domains are present for control shapes pca initialization
228+ // Maybe, initialize with representative clusters?
229+ // control_shapes_.resize(0, 0);
230+ // throw std::runtime_error(
231+ // "Fix some domains with particles initialized to use early stopping "
232+ // "feature in fixed domains optimization");
233+ }
234+ return score_func_.SetControlShapes (control_shapes_);
235+
236+ }
237+
238+ // ---------------------------------------------------------------------------
239+ Eigen::MatrixXd EarlyStopping::GetTestShapes (const ParticleSystem* p) {
240+ unsigned int numdomains = p->GetNumberOfDomains ();
241+ const auto domains_per_shape = p->GetDomainsPerShape ();
242+ unsigned int num_shapes = numdomains / domains_per_shape;
243+ std::vector<Eigen::RowVectorXd> shape_vectors;
244+ Eigen::MatrixXd test_shapes;
245+ for (size_t shape = 0 ; shape < num_shapes; ++shape) {
246+ for (int shape_dom_idx = 0 ; shape_dom_idx < domains_per_shape;
247+ ++shape_dom_idx) {
248+ auto dom = shape * domains_per_shape + shape_dom_idx;
249+
250+ if (!p->GetDomainFlag (dom)) {
251+ auto num_points = p->GetPositions (dom)->GetSize ();
252+ Eigen::RowVectorXd shape_vector (VDimension * num_points);
253+
254+ for (int k = 0 ; k < num_points; ++k) {
255+ // PointType pt = p->GetTransformedPosition(dom, k);
256+ PointType pt = p->GetPositions (dom)->Get (k);
257+ shape_vector (3 * k + 0 ) = pt[0 ];
258+ shape_vector (3 * k + 1 ) = pt[1 ];
259+ shape_vector (3 * k + 2 ) = pt[2 ];
260+ }
261+
262+ shape_vectors.push_back (shape_vector);
263+ }
264+ }
265+ }
266+ int num_test_shapes = static_cast <int >(shape_vectors.size ());
267+ // std::cout << "DEBUG | Found " << num_test_shapes << " Test Shapes " << std::endl;
268+ if (num_test_shapes > 0 ) {
269+ int d = static_cast <int >(shape_vectors[0 ].cols ());
270+ test_shapes.resize (num_test_shapes, d);
271+ for (int i = 0 ; i < num_test_shapes; ++i) {
272+ test_shapes.row (i) = shape_vectors[i];
273+ }
274+ } else {
275+ test_shapes.resize (0 , 0 );
276+ }
277+ return test_shapes;
278+ }
279+
280+ } // namespace shapeworks
0 commit comments