Skip to content

Commit 7357643

Browse files
authored
Merge pull request #2435 from SCIInstitute/early_stop_with_mahalanobis
Add Early Stopping to Optimization backend
2 parents 4dd310f + b499109 commit 7357643

14 files changed

+701
-3
lines changed

Libs/Optimize/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ FILE(GLOB Optimize_sources
44
Container/*.cpp
55
Domain/*.cpp
66
Function/*.cpp
7+
Function/EarlyStop/*.cpp
78
Matrix/*.cpp
89
Neighborhood/*.cpp
910
Utils/*.cpp
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once

#include <string>
2+
3+
namespace shapeworks {

/// Convergence tests the optimizer can use to decide when to stop early.
enum class EarlyStoppingStrategy {
  /// Stop once the relative improvement drops below the threshold.
  RelativeDifference,
  /// Stop once an exponential moving average (EMA) of the improvements
  /// indicates convergence.
  ExponentialMovingAverage
};

/// User-facing settings for the early stopping criterion.
struct EarlyStoppingConfig {
  /// Master switch; early stopping is off unless this is set.
  bool enabled{false};

  /// How often (in iterations) the stopping criterion is checked.
  int frequency{100};

  /// Number of past values considered by the relative difference or EMA test.
  int window_size{5};

  /// Stopping threshold.
  ///   - RelativeDifference: use ~1e-4 or smaller.
  ///   - ExponentialMovingAverage: use ~1e-1 or smaller.
  double threshold{1e-4};

  /// Strategy used to decide whether to stop.
  EarlyStoppingStrategy strategy{EarlyStoppingStrategy::RelativeDifference};

  /// EMA smoothing factor; higher values give more weight to recent iterations.
  double ema_alpha{0.2};

  /// When true, per-check status logging is enabled.
  bool enable_logging{false};

  /// Name attached to the early stopping log output.
  std::string logger_name{"early_stopping_log_stats"};

  /// Number of initial iterations to skip, to prevent premature stopping.
  int warmup_iters{1000};
};

}  // namespace shapeworks
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
#include "EarlyStopping.h"
2+
#include "MorphologicalDeviationScore.h"
3+
#include <Eigen/Dense>
4+
#include <Profiling.h>
5+
#include <Logging.h>
6+
7+
namespace shapeworks {
8+
9+
/// Builds an early-stopping monitor with built-in defaults: check every 100
/// iterations, window of 5, threshold 1e-4, relative-difference strategy,
/// EMA alpha 0.2, logging off, and a 1000-iteration warm-up.
///
/// The initializers are listed in the order the members are declared in the
/// class. C++ always initializes members in declaration order regardless of
/// how this list is written, so keeping the two in sync avoids -Wreorder
/// warnings and a misleading reading order.
EarlyStopping::EarlyStopping()
    : frequency_(100),
      window_size_(5),
      threshold_(0.0001),
      ema_alpha_(0.2),
      last_checked_iter_(-1),  // -1 means no check has been performed yet
      warmup_iters_(1000),
      enable_logging_(false),
      logger_name_("early_stopping"),
      score_func_(),
      strategy_(EarlyStoppingStrategy::RelativeDifference) {}
20+
21+
//---------------------------------------------------------------------------
22+
void EarlyStopping::SetConfigParams(int frequency, int window_size,
23+
double threshold,
24+
EarlyStoppingStrategy strategy,
25+
double ema_alpha, bool enable_logging,
26+
const std::string& logger_name,
27+
int warmup_iters) {
28+
frequency_ = frequency;
29+
window_size_ = window_size;
30+
threshold_ = threshold;
31+
strategy_ = strategy;
32+
ema_alpha_ = ema_alpha;
33+
enable_logging_ = enable_logging;
34+
logger_name_ = logger_name;
35+
last_checked_iter_ = -1;
36+
warmup_iters_ = warmup_iters;
37+
}
38+
39+
//---------------------------------------------------------------------------
40+
/// Returns the monitor to its "nothing observed yet" state: the stop request
/// is withdrawn, the score history and EMA accumulator are emptied, and the
/// last-checked iteration marker is rewound to -1. Configuration values and
/// the control shapes are not touched.
void EarlyStopping::reset() {
  stop_flag_ = false;  // atomic assignment, same effect as store(false)
  last_checked_iter_ = -1;

  score_history_.clear();

  ema_diff_.resize(0);
  ema_initialized_ = false;
}
48+
49+
//---------------------------------------------------------------------------
50+
void EarlyStopping::update(int iteration, const ParticleSystem* p) {
51+
if (stop_flag_ || (iteration - last_checked_iter_ < frequency_)) return;
52+
53+
if (iteration < warmup_iters_) return;
54+
55+
if (control_shapes_.size() == 0) {
56+
throw std::runtime_error("Control shapes not initialized yet.");
57+
}
58+
59+
TIME_START("early_stopping_score");
60+
Eigen::MatrixXd X = GetTestShapes(p);
61+
if (X.size() == 0) return;
62+
63+
last_checked_iter_ = iteration;
64+
Eigen::VectorXd score = ComputeScore(X);
65+
66+
SW_DEBUG("Early stopping score mean at optimization iteration {} = {}", iteration, score.mean());
67+
68+
if (!ema_initialized_) {
69+
ema_diff_ = Eigen::VectorXd::Zero(score.size());
70+
}
71+
72+
score_history_.push_back(score);
73+
if (score_history_.size() > window_size_ + 1) score_history_.pop_front();
74+
75+
if (score_history_.size() >= 2) {
76+
if (HasConverged())
77+
// stop_flag_ = true;
78+
stop_flag_.store(true);
79+
}
80+
TIME_STOP("early_stopping_score");
81+
82+
}
83+
84+
// bool EarlyStopping::ShouldStop() const {
85+
// return stop_flag_;
86+
// }
87+
88+
//---------------------------------------------------------------------------
89+
/// Reports whether a previous update() decided the optimization should stop.
bool EarlyStopping::ShouldStop() const {
  // Implicit conversion of the atomic performs the same load as .load().
  const bool stop_requested = stop_flag_;
  return stop_requested;
}
90+
91+
//---------------------------------------------------------------------------
92+
/// Scores the test shapes in `X` by forwarding them to the scorer's
/// GetMahalanobisDistance().
///
/// @param X matrix of test shapes as produced by GetTestShapes()
/// @return the vector returned by the scorer (one score per test shape is
///         assumed by the callers in this file)
Eigen::VectorXd EarlyStopping::ComputeScore(const Eigen::MatrixXd& X) {
  Eigen::VectorXd distances = score_func_.GetMahalanobisDistance(X);
  return distances;
}
95+
96+
//---------------------------------------------------------------------------
97+
/// Element-wise relative change of `a` with respect to `b`:
/// |a - b| / (|b| + 1e-8).
///
/// @param a new values
/// @param b reference values; must have the same length as `a`
/// @return vector of non-negative relative differences
Eigen::VectorXd EarlyStopping::ComputeRelativeDiff(const Eigen::VectorXd& a,
                                                   const Eigen::VectorXd& b) const {
  // Small offset so that a zero reference value does not divide by zero.
  const double kEpsilon = 1e-8;

  const Eigen::ArrayXd numerator = (a - b).array().abs();
  const Eigen::ArrayXd denominator = b.array().abs() + kEpsilon;
  return (numerator / denominator).matrix();
}
102+
103+
//---------------------------------------------------------------------------
104+
bool EarlyStopping::HasConverged() const {
105+
106+
switch (strategy_) {
107+
case EarlyStoppingStrategy::RelativeDifference:
108+
return CheckRelativeDifference();
109+
case EarlyStoppingStrategy::ExponentialMovingAverage:
110+
return CheckExponentialMovingAverage();
111+
default:
112+
return false;
113+
}
114+
}
115+
116+
//---------------------------------------------------------------------------
117+
bool EarlyStopping::CheckRelativeDifference() const {
118+
119+
const Eigen::VectorXd& current = score_history_.back();
120+
const Eigen::VectorXd& prev = score_history_[score_history_.size() - 2];
121+
122+
Eigen::VectorXd rel_diff = ComputeRelativeDiff(current, prev);
123+
std::vector<bool> converged(rel_diff.size());
124+
for (int i = 0; i < rel_diff.size(); ++i) {
125+
converged[i] = rel_diff[i] < threshold_;
126+
// std::cout << "DEBUG | Test Shape " << i << "/" << rel_diff.size() << " rel_diff = " << rel_diff[i] << " converged = " << converged[i];
127+
}
128+
129+
if (enable_logging_)
130+
LogStatus(last_checked_iter_, current, rel_diff, converged);
131+
132+
return std::all_of(converged.begin(), converged.end(),
133+
[](bool x) { return x; });
134+
}
135+
136+
//---------------------------------------------------------------------------
137+
bool EarlyStopping::CheckExponentialMovingAverage() const {
138+
const Eigen::VectorXd& current = score_history_.back();
139+
const Eigen::VectorXd& prev = score_history_[score_history_.size() - 2];
140+
141+
Eigen::VectorXd abs_diff = (current - prev).cwiseAbs();
142+
143+
if (!ema_initialized_) {
144+
ema_diff_ = abs_diff;
145+
ema_initialized_ = true;
146+
} else {
147+
ema_diff_ = ema_alpha_ * abs_diff + (1.0 - ema_alpha_) * ema_diff_;
148+
}
149+
150+
std::vector<bool> converged(ema_diff_.size());
151+
for (int i = 0; i < ema_diff_.size(); ++i)
152+
{
153+
converged[i] = ema_diff_[i] < threshold_;
154+
// std::cout << "DEBUG | Test Shape " << i << "/" << ema_diff_.size() << " ema_diff_ = " << ema_diff_[i] << " converged = " << converged[i];
155+
156+
}
157+
158+
if (enable_logging_)
159+
LogStatus(last_checked_iter_, current, ema_diff_, converged);
160+
161+
return std::all_of(converged.begin(), converged.end(),
162+
[](bool x) { return x; });
163+
}
164+
165+
//---------------------------------------------------------------------------
166+
void EarlyStopping::LogStatus(
167+
int iter, const Eigen::VectorXd& score, const Eigen::VectorXd& diff,
168+
const std::vector<bool>& per_subject_convergence) const {
169+
std::cout << "[EarlyStopping";
170+
if (!logger_name_.empty()) std::cout << " - " << logger_name_;
171+
std::cout << "] Iteration " << iter << ":\n";
172+
173+
for (int i = 0; i < score.size(); ++i) {
174+
std::cout << " TestShape[" << i << "]: Score=" << score[i]
175+
<< ", Diff=" << diff[i]
176+
<< ", Status= " << (per_subject_convergence[i] ? "Converged" : "Not Converged")
177+
<< "\n";
178+
}
179+
180+
int num_converged = std::count(per_subject_convergence.begin(),
181+
per_subject_convergence.end(), true);
182+
std::cout << " --> " << num_converged << " / " << score.size()
183+
<< " shapes converged\n\n";
184+
}
185+
186+
//---------------------------------------------------------------------------
187+
bool EarlyStopping::SetControlShapes(const ParticleSystem* p) {
188+
if (control_shapes_.size() > 0) return true; // already initialized return back
189+
190+
unsigned int numdomains = p->GetNumberOfDomains();
191+
const auto domains_per_shape = p->GetDomainsPerShape();
192+
unsigned int num_shapes = numdomains / domains_per_shape;
193+
194+
std::vector<Eigen::RowVectorXd> shape_vectors;
195+
for (size_t shape = 0; shape < num_shapes; ++shape) {
196+
for (int shape_dom_idx = 0; shape_dom_idx < domains_per_shape;
197+
++shape_dom_idx) {
198+
auto dom = shape * domains_per_shape + shape_dom_idx;
199+
200+
if (p->GetDomainFlag(dom) == true) {
201+
auto num_points = p->GetPositions(dom)->GetSize();
202+
Eigen::RowVectorXd shape_vector(VDimension * num_points);
203+
204+
for (int k = 0; k < num_points; ++k) {
205+
// PointType pt = p->GetTransformedPosition(dom, k);
206+
PointType pt = p->GetPositions(dom)->Get(k);
207+
208+
shape_vector(VDimension * k + 0) = pt[0];
209+
shape_vector(VDimension * k + 1) = pt[1];
210+
shape_vector(VDimension * k + 2) = pt[2];
211+
}
212+
213+
shape_vectors.push_back(shape_vector);
214+
}
215+
}
216+
}
217+
218+
int num_control_shapes = static_cast<int>(shape_vectors.size());
219+
if (num_control_shapes > 0) {
220+
int d = static_cast<int>(shape_vectors[0].cols());
221+
control_shapes_.resize(num_control_shapes, d);
222+
for (int i = 0; i < num_control_shapes; ++i) {
223+
control_shapes_.row(i) = shape_vectors[i];
224+
}
225+
} else {
226+
return false; // Forcibly turn off early stopping now if no fixed domains present
227+
// TODO: handle non-fixed domain cases when no fixed domains are present for control shapes pca initialization
228+
// Maybe, initialize with representative clusters?
229+
// control_shapes_.resize(0, 0);
230+
// throw std::runtime_error(
231+
// "Fix some domains with particles initialized to use early stopping "
232+
// "feature in fixed domains optimization");
233+
}
234+
return score_func_.SetControlShapes(control_shapes_);
235+
236+
}
237+
238+
//---------------------------------------------------------------------------
239+
/// Collects the current test shapes from the particle system.
///
/// Every domain whose GetDomainFlag() is false contributes one row holding
/// its particle positions flattened as (x0, y0, z0, x1, y1, z1, ...), using
/// the same layout as SetControlShapes().
///
/// @param p particle system to read positions from
/// @return matrix with one row per unflagged domain. An empty (0x0) matrix is
///         returned when there are no such domains, when they do not all
///         have the same number of values, or when GetDomainsPerShape() is
///         zero.
Eigen::MatrixXd EarlyStopping::GetTestShapes(const ParticleSystem* p) {
  Eigen::MatrixXd test_shapes;  // default-constructed: 0x0

  const unsigned int numdomains = p->GetNumberOfDomains();
  const auto domains_per_shape = p->GetDomainsPerShape();
  if (domains_per_shape == 0) {
    return test_shapes;  // avoids dividing by zero below
  }
  const unsigned int num_shapes = numdomains / domains_per_shape;

  std::vector<Eigen::RowVectorXd> shape_vectors;
  for (size_t shape = 0; shape < num_shapes; ++shape) {
    for (int shape_dom_idx = 0; shape_dom_idx < domains_per_shape; ++shape_dom_idx) {
      const auto dom = shape * domains_per_shape + shape_dom_idx;
      if (p->GetDomainFlag(dom)) {
        continue;  // flagged domains are control shapes, not test shapes
      }

      const auto num_points = p->GetPositions(dom)->GetSize();
      Eigen::RowVectorXd shape_vector(VDimension * num_points);
      for (int k = 0; k < num_points; ++k) {
        // Raw positions are used here, not GetTransformedPosition().
        const PointType pt = p->GetPositions(dom)->Get(k);
        // VDimension instead of a literal 3 keeps the packing in step with
        // the size the vector was allocated with.
        for (int axis = 0; axis < VDimension; ++axis) {
          shape_vector(VDimension * k + axis) = pt[axis];
        }
      }
      shape_vectors.push_back(shape_vector);
    }
  }

  const int num_test_shapes = static_cast<int>(shape_vectors.size());
  if (num_test_shapes == 0) {
    return test_shapes;
  }

  // All rows must have the same length before they can be stacked; a
  // mismatched row assignment would otherwise trip an Eigen assertion or
  // write out of bounds.
  const auto dim = shape_vectors.front().cols();
  for (const auto& shape_vector : shape_vectors) {
    if (shape_vector.cols() != dim) {
      SW_DEBUG("Early stopping check skipped: test shapes have differing sizes ({} vs {})",
               shape_vector.cols(), dim);
      return test_shapes;
    }
  }

  test_shapes.resize(num_test_shapes, dim);
  for (int i = 0; i < num_test_shapes; ++i) {
    test_shapes.row(i) = shape_vectors[i];
  }
  return test_shapes;
}
279+
280+
} // namespace shapeworks
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#pragma once
2+
#include <Eigen/Dense>
3+
4+
#include "Libs/Optimize/ParticleSystem.h"
5+
#include "Libs/Optimize/EarlyStoppingConfig.h"
6+
#include "MorphologicalDeviationScore.h"
7+
8+
namespace shapeworks {
9+
10+
class EarlyStopping {
11+
public:
12+
typedef typename ParticleSystem::PointType PointType;
13+
constexpr static int VDimension = 3;
14+
EarlyStopping();
15+
void SetConfigParams(int frequency,
16+
int window_size,
17+
double threshold,
18+
EarlyStoppingStrategy strategy = EarlyStoppingStrategy::RelativeDifference,
19+
double ema_alpha = 0.2,
20+
bool enable_logging = false,
21+
const std::string& logger_name = "",
22+
int warmup_iters = 1000);
23+
24+
void reset();
25+
void update(int iteration, const ParticleSystem* p);
26+
bool ShouldStop() const;
27+
bool SetControlShapes(const ParticleSystem* p);
28+
Eigen::MatrixXd GetTestShapes(const ParticleSystem* p);
29+
30+
private:
31+
std::deque<Eigen::VectorXd> score_history_;
32+
int frequency_, window_size_;
33+
double threshold_, ema_alpha_;
34+
int last_checked_iter_;
35+
int warmup_iters_;
36+
// bool stop_flag_;
37+
mutable std::atomic<bool> stop_flag_{false};
38+
bool enable_logging_;
39+
std::string logger_name_;
40+
Eigen::MatrixXd control_shapes_;
41+
MorphologicalDeviationScore score_func_;
42+
EarlyStoppingStrategy strategy_;
43+
mutable Eigen::VectorXd ema_diff_;
44+
mutable bool ema_initialized_ = false;
45+
46+
Eigen::VectorXd ComputeScore(const Eigen::MatrixXd& X) ;
47+
Eigen::VectorXd ComputeRelativeDiff(const Eigen::VectorXd& a, const Eigen::VectorXd& b) const;
48+
bool HasConverged() const;
49+
bool CheckRelativeDifference() const;
50+
bool CheckExponentialMovingAverage() const;
51+
void LogStatus(int iter,
52+
const Eigen::VectorXd& current_score,
53+
const Eigen::VectorXd& diff,
54+
const std::vector<bool>& per_subject_convergence) const;
55+
};
56+
57+
} // namespace shapeworks

0 commit comments

Comments
 (0)