Skip to content

Commit 095cdf6

Browse files
authored
Increase tolerance for jaxmd test (#1326)
* Increase tolerance for jaxmd test * Make tolerance for jaxmd test backend-dependent * Tweak jaxmd tolerance with TPU backend
1 parent dfa7e7b commit 095cdf6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

test/jaxmd.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,16 @@ def forward(
120120
# TODO: This is horribly slow for reasons which are unknown.
121121
self.mlirad_fwd = False
122122

123-
self.tol = 5e-4
123+
if jax.default_backend() == "tpu":
124+
self.tol = 5e-3
124125

125-
# GPU CI reverse mode needs loose, merits future investigation
126-
self.tol = 1e-2
126+
elif jax.default_backend() == "gpu":
127+
# GPU CI reverse mode needs loose, merits future investigation
128+
self.tol = 1e-2
127129

128-
self.tol = 0.04
130+
else:
131+
# CPU backend needs loose tolerance, see #1289
132+
self.tol = 0.07
129133

130134

131135
if __name__ == "__main__":

0 commit comments

Comments
 (0)