Skip to content

Commit 080a903

Browse files
core logic for morphological deviation score | criteria for early stopping
1 parent 56e1c24 commit 080a903

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#include "MorphologicalDeviationScore.h"
2+
#include <Logging.h>
3+
4+
namespace shapeworks {
5+
6+
MorphologicalDeviationScore::MorphologicalDeviationScore() = default;
7+
8+
//---------------------------------------------------------------------------
9+
bool MorphologicalDeviationScore::SetControlShapes(const Eigen::MatrixXd& X) {
10+
try {
11+
bool ppca_status = FitPPCA(X);
12+
precision_matrix_ = ComputePrecisionMatrix();
13+
is_fitted_ = ppca_status;
14+
} catch (std::exception& e) {
15+
SW_ERROR("Exception in setting control shapes for early stopping {}",
16+
e.what());
17+
return false;
18+
}
19+
return is_fitted_;
20+
}
21+
22+
//---------------------------------------------------------------------------
23+
bool MorphologicalDeviationScore::FitPPCA(const Eigen::MatrixXd& X) {
24+
const int n = X.rows(); // n_samples
25+
const int d = X.cols(); // n_features
26+
27+
mean_ = X.colwise().mean(); // (1 x d)
28+
Eigen::MatrixXd X_c = X.rowwise() - mean_; // (n x d)
29+
30+
try {
31+
Eigen::JacobiSVD<Eigen::MatrixXd> svd(
32+
X_c, Eigen::ComputeThinU | Eigen::ComputeThinV);
33+
Eigen::VectorXd s = svd.singularValues(); // (r,)
34+
Eigen::MatrixXd V = svd.matrixV(); // (d x r)
35+
Eigen::VectorXd eigvals =
36+
(s.array().square()) / std::max(1, n - 1); // (r,)
37+
38+
double total_var = eigvals.sum();
39+
if (total_var <= 0.0) {
40+
throw std::runtime_error(
41+
"Total variance is non-positive. Cannot fit PCA on control shapes");
42+
}
43+
44+
Eigen::VectorXd cum_var = eigvals;
45+
for (int i = 1; i < cum_var.size(); ++i) {
46+
cum_var(i) += cum_var(i - 1);
47+
}
48+
cum_var /= total_var;
49+
50+
int q = 0;
51+
for (int i = 0; i < cum_var.size(); ++i) {
52+
if (cum_var(i) > 0.95) {
53+
q = i + 1;
54+
break;
55+
}
56+
}
57+
if (q == 0) q = cum_var.size();
58+
59+
eigvals_ = eigvals.head(q); // (q,)
60+
components_ = V.leftCols(q); // (d x q)
61+
62+
if (q < d) {
63+
double rem_sum = eigvals.tail(eigvals.size() - q).sum();
64+
noise_variance_ = rem_sum / (d - q);
65+
} else {
66+
noise_variance_ = 0.0;
67+
}
68+
return true;
69+
70+
} catch (std::exception& e) {
71+
SW_ERROR(
72+
"Exception in SVD computation for early stopping score function {}",
73+
e.what());
74+
return false;
75+
}
76+
}
77+
78+
//---------------------------------------------------------------------------
79+
Eigen::MatrixXd MorphologicalDeviationScore::ComputeCovarianceMatrix() {
80+
const int d = components_.rows();
81+
const int q = components_.cols();
82+
83+
if (q == 0) {
84+
return noise_variance_ * Eigen::MatrixXd::Identity(d, d);
85+
}
86+
87+
Eigen::MatrixXd diag_lambda = eigvals_.asDiagonal(); // (q x q)
88+
Eigen::MatrixXd cov =
89+
components_ *
90+
(diag_lambda - noise_variance_ * Eigen::MatrixXd::Identity(q, q)) *
91+
components_.transpose() +
92+
noise_variance_ * Eigen::MatrixXd::Identity(d, d); // (d x d)
93+
return cov;
94+
}
95+
96+
//---------------------------------------------------------------------------
97+
Eigen::MatrixXd MorphologicalDeviationScore::ComputePrecisionMatrix() {
98+
try {
99+
const int d = components_.rows();
100+
const int q = components_.cols();
101+
102+
if (q == 0) {
103+
if (noise_variance_ <= 0.0)
104+
throw std::runtime_error(
105+
"Noise variance is zero; precision undefined.");
106+
return (1.0 / noise_variance_) * Eigen::MatrixXd::Identity(d, d);
107+
}
108+
109+
Eigen::MatrixXd P = components_ * components_.transpose(); // (d x d)
110+
Eigen::MatrixXd term_principal = components_ *
111+
eigvals_.cwiseInverse().asDiagonal() *
112+
components_.transpose(); // (d x d)
113+
114+
if (noise_variance_ <= 0.0) {
115+
if (q != d)
116+
throw std::runtime_error(
117+
"Noise variance is zero and q < d; covariance is singular.");
118+
return term_principal;
119+
}
120+
121+
Eigen::MatrixXd precision =
122+
term_principal + (1.0 / noise_variance_) *
123+
(Eigen::MatrixXd::Identity(d, d) - P); // (d x d)
124+
return precision;
125+
126+
} catch (std::exception& e) {
127+
SW_ERROR(
128+
"Exception in computation of precision matrix in early stopping score "
129+
"function {}",
130+
e.what());
131+
return Eigen::MatrixXd();
132+
}
133+
}
134+
135+
//---------------------------------------------------------------------------
136+
Eigen::VectorXd MorphologicalDeviationScore::GetMahalanobisDistance(
137+
const Eigen::MatrixXd& X) {
138+
try {
139+
if (!is_fitted_) {
140+
throw std::runtime_error(
141+
"PCA model has not been fitted yet on control shapes.");
142+
}
143+
const int n = X.rows();
144+
Eigen::MatrixXd X_c = X.rowwise() - mean_;
145+
Eigen::VectorXd dist(n);
146+
for (int i = 0; i < n; ++i) {
147+
Eigen::RowVectorXd xi = X_c.row(i);
148+
dist(i) = std::sqrt(
149+
std::max(0.0, (xi * precision_matrix_ * xi.transpose())(0)));
150+
}
151+
return dist;
152+
} catch (std::exception& e) {
153+
SW_ERROR(
154+
"Exception in computing Mahalanobis distance for early stopping score "
155+
"function {}",
156+
e.what());
157+
return Eigen::VectorXd();
158+
}
159+
}
160+
161+
//---------------------------------------------------------------------------
162+
// double MorphologicalDeviationScore::GetDeviationScore(const Eigen::MatrixXd&
163+
// X_test) {
164+
// if (!is_fitted_) {
165+
// throw std::runtime_error("PPCA model has not been fitted. Call
166+
// SetControlShapes() first.");
167+
// }
168+
// Eigen::VectorXd dists = GetMahalanobisDistance(X_test);
169+
// return dists(0);
170+
// }
171+
172+
} // namespace shapeworks
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <Eigen/Dense>
4+
5+
namespace shapeworks {
6+
class MorphologicalDeviationScore {
7+
public:
8+
MorphologicalDeviationScore();
9+
// Fit PPCA model on control shapes
10+
bool SetControlShapes(const Eigen::MatrixXd& X); // (n_samples x n_features)
11+
/// Get Mahalanobis-based deviation score for test samples (non-fixed shapes/domains)
12+
Eigen::VectorXd GetMahalanobisDistance(const Eigen::MatrixXd& X); // (n,)
13+
14+
private:
15+
/// Flag to ensure control shapes are set and PCA model is in place
16+
bool is_fitted_ = false;
17+
18+
// Fitted model parameters
19+
Eigen::RowVectorXd mean_; // (1 x d)
20+
Eigen::MatrixXd components_; // (d x q)
21+
Eigen::VectorXd eigvals_; // (q,)
22+
double noise_variance_;
23+
// Derived matrices
24+
Eigen::MatrixXd precision_matrix_; // (d x d)
25+
// Helper functions
26+
bool FitPPCA(const Eigen::MatrixXd& X);
27+
Eigen::MatrixXd ComputeCovarianceMatrix();
28+
Eigen::MatrixXd ComputePrecisionMatrix();
29+
};
30+
} // namespace shapeworks

0 commit comments

Comments
 (0)