Skip to content

Commit 7356028

Browse files
committed
feat: Do not report invalid gradients for transform adapt
1 parent 1db17a0 commit 7356028

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/transform_adapt_strategy.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
114114
let point = end.point();
115115
let energy_error = point.energy_error();
116116
if energy_error.abs() < self.max_energy_error {
117+
if !math.array_all_finite(point.position()) {
118+
return;
119+
}
120+
if !math.array_all_finite(point.gradient()) {
121+
return;
122+
}
117123
self.draws.push(math.copy_array(point.position()));
118124
self.grads.push(math.copy_array(point.gradient()));
119125
}
@@ -125,6 +131,12 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
125131
let point = state.point();
126132
let energy_error = point.energy_error();
127133
if energy_error.abs() < self.max_energy_error {
134+
if !math.array_all_finite(point.position()) {
135+
return;
136+
}
137+
if !math.array_all_finite(point.gradient()) {
138+
return;
139+
}
128140
self.draws.push(math.copy_array(point.position()));
129141
self.grads.push(math.copy_array(point.gradient()));
130142
}

0 commit comments

Comments
 (0)