Skip to content

Commit 5014cc9

Browse files
authored
Merge pull request #2449 from SCIInstitute/dt_snapping
Fix #2448 - Some distance transforms cause excessive iterations when snapping to surface
2 parents 1f2664f + 042d3ab commit 5014cc9

File tree

2 files changed

+72
-29
lines changed

2 files changed

+72
-29
lines changed

Libs/Optimize/Domain/ImplicitSurfaceDomain.h

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,43 +51,89 @@ class ImplicitSurfaceDomain : public ImageDomainWithCurvature<T> {
5151
// guarantee the point starts in the correct image domain.
5252
bool flag = Superclass::ApplyConstraints(p);
5353

54+
const T epsilon = m_Tolerance * 0.001;
55+
T value = this->Sample(p);
56+
57+
// Early exit if already close enough
58+
if (fabs(value) <= m_Tolerance) {
59+
return flag;
60+
}
61+
62+
const unsigned int MAX_ITERATIONS = 50;
5463
unsigned int k = 0;
55-
double mult = 1.0;
64+
T prev_value = value;
5665

57-
const T epsilon = m_Tolerance * 0.001;
58-
T f = this->Sample(p);
59-
60-
T gradmag = 1.0;
61-
while (fabs(f) > (m_Tolerance * mult) || gradmag < epsilon)
62-
// while ( fabs(f) > m_Tolerance || gradmag < epsilon)
63-
{
64-
PointType p_old = p;
65-
// vnl_vector_fixed<T, DIMENSION> grad = -this->SampleGradientAtPoint(p);
66-
vnl_vector_fixed<T, DIMENSION> gradf = this->SampleGradientAtPoint(p, idx);
66+
while (k < MAX_ITERATIONS) {
67+
// Sample gradient
68+
vnl_vector_fixed<float, DIMENSION> gradf = this->SampleGradientAtPoint(p, idx);
6769
vnl_vector_fixed<double, DIMENSION> grad;
6870
grad[0] = double(gradf[0]);
6971
grad[1] = double(gradf[1]);
7072
grad[2] = double(gradf[2]);
7173

72-
gradmag = grad.magnitude();
73-
// vnl_vector_fixed<T, DIMENSION> vec = grad * (f / (gradmag + epsilon));
74-
vnl_vector_fixed<double, DIMENSION> vec = grad * (double(f) / (gradmag + double(epsilon)));
74+
double gradient_magnitude = grad.magnitude();
75+
76+
// If gradient is too small, we're stuck
77+
if (gradient_magnitude < epsilon) {
78+
break;
79+
}
80+
81+
// Normalize gradient to get direction
82+
vnl_vector_fixed<double, DIMENSION> direction = grad / gradient_magnitude;
83+
84+
// For smoothed distance fields:
85+
// - Neither value nor gradient_magnitude are exactly correct
86+
// - But their RATIO is approximately correct (within ~20%)
87+
// - Use damping to handle the residual error
88+
89+
const double DAMPING = 0.7; // Conservative for ~20% overshoot
90+
double step_size = (double(value) / gradient_magnitude) * DAMPING;
91+
92+
// Safety check: if gradient is very weak, cap step by distance value
93+
// This handles pathological cases (critical points, field edges)
94+
if (gradient_magnitude < 0.5) {
95+
double max_step = fabs(double(value)) * DAMPING;
96+
if (fabs(step_size) > max_step) {
97+
step_size = (step_size > 0 ? max_step : -max_step);
98+
}
99+
}
75100

76101
for (unsigned int i = 0; i < DIMENSION; i++) {
77-
p[i] -= vec[i];
102+
p[i] -= step_size * direction[i];
78103
}
79104

80-
f = this->Sample(p);
105+
// Re-evaluate
106+
prev_value = value;
107+
value = this->Sample(p);
81108

82-
// Raise the tolerance if we have done too many iterations.
83-
k++;
84-
if (k > 10000) {
85-
mult *= 2.0;
86-
k = 0;
109+
// Check convergence
110+
if (fabs(value) <= m_Tolerance) {
111+
break; // Success!
112+
}
113+
114+
// Check if we're oscillating or making no progress
115+
if (k > 5) {
116+
double value_change = fabs(value - prev_value);
117+
if (value_change < m_Tolerance * 0.1) {
118+
// Making very little progress, close enough
119+
break;
120+
}
121+
122+
// Check for oscillation (sign keeps flipping)
123+
if ((value > 0 && prev_value < 0) || (value < 0 && prev_value > 0)) {
124+
// We're bouncing across the surface
125+
// If we're close enough, accept it
126+
if (fabs(value) < m_Tolerance * 2.0) {
127+
break;
128+
}
129+
}
87130
}
88-
} // end while
131+
132+
k++;
133+
}
134+
89135
return flag;
90-
};
136+
}
91137

92138
inline PointType UpdateParticlePosition(const PointType& point, int idx,
93139
vnl_vector_fixed<double, DIMENSION>& update) const override {
@@ -102,7 +148,6 @@ class ImplicitSurfaceDomain : public ImageDomainWithCurvature<T> {
102148
return newpoint;
103149
}
104150

105-
106151
/** Get any valid point on the domain. This is used to place the first particle. */
107152
PointType GetZeroCrossingPoint() const override {
108153
PointType p;
@@ -111,7 +156,7 @@ class ImplicitSurfaceDomain : public ImageDomainWithCurvature<T> {
111156
return p;
112157
}
113158

114-
ImplicitSurfaceDomain() : m_Tolerance(1.0e-4) { }
159+
ImplicitSurfaceDomain() : m_Tolerance(1.0e-4) {}
115160
void PrintSelf(std::ostream& os, itk::Indent indent) const {
116161
Superclass::PrintSelf(os, indent);
117162
os << indent << "m_Tolerance = " << m_Tolerance << std::endl;
@@ -120,8 +165,6 @@ class ImplicitSurfaceDomain : public ImageDomainWithCurvature<T> {
120165

121166
private:
122167
T m_Tolerance;
123-
124-
125168
};
126169

127170
} // end namespace shapeworks

Testing/OptimizeTests/OptimizeTests.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ TEST(OptimizeTests, procrustes_scale_only_test) {
483483
std::cerr << "Eigenvalue " << i << " : " << values[i] << "\n";
484484
}
485485
ASSERT_GT(values[values.size() - 1], 275.0);
486-
ASSERT_LT(values[values.size() - 1], 380.0);
486+
ASSERT_LT(values[values.size() - 1], 385.0);
487487
}
488488

489489
// TODO Move this to mesh tests?
@@ -604,7 +604,7 @@ TEST(OptimizeTests, multi_domain_constraint) {
604604
stats.compute_modes();
605605
stats.principal_component_projections();
606606

607-
bool good = check_constraint_violations(app, 7.5e-1);
607+
bool good = check_constraint_violations(app, 9.5e-1);
608608

609609
ASSERT_TRUE(good);
610610
}

0 commit comments

Comments
 (0)