Skip to content

Commit f3a97c2

Browse files
authored
add nan safe log&divide (#2611)
1 parent 851d389 commit f3a97c2

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

python/sdist/amici/jax.template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from interpax import interp1d
44
from pathlib import Path
55

6-
from amici.jax.model import JAXModel
6+
from amici.jax.model import JAXModel, safe_log, safe_div
77

88

99
class JAXModel_TPL_MODEL_NAME(JAXModel):

python/sdist/amici/jax/model.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,39 @@ def simulate_condition(
559559
stats_dyn=stats_dyn,
560560
stats_posteq=stats_posteq,
561561
)
562+
563+
564+
def safe_log(x: jnp.float_) -> jnp.float_:
565+
"""
566+
Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0.
567+
568+
:param x:
569+
input
570+
:return:
571+
logarithm of x
572+
"""
573+
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
574+
# against nans in forward & backward passes
575+
safe_x = jnp.where(
576+
x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps
577+
)
578+
return jnp.where(
579+
x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps)
580+
)
581+
582+
583+
def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_:
584+
"""
585+
Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`.
586+
587+
:param x:
588+
numerator
589+
:param y:
590+
denominator
591+
:return:
592+
x / y
593+
"""
594+
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
595+
# against nans in forward & backward passes
596+
safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps)
597+
return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps)

python/sdist/amici/jaxcodeprinter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str:
2727
# FIXME: untested, where are spline nodes coming from anyways?
2828
return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")'
2929

30+
def _print_log(self, expr: sp.Expr) -> str:
31+
return f"safe_log({self.doprint(expr.args[0])})"
32+
33+
def _print_Mul(self, expr: sp.Expr) -> str:
34+
numer, denom = expr.as_numer_denom()
35+
if denom == 1:
36+
return super()._print_Mul(expr)
37+
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"
38+
3039
def _get_sym_lines(
3140
self,
3241
symbols: sp.Matrix | Iterable[str],

tests/benchmark-models/test_petab_benchmark.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem):
299299

300300
np.random.seed(cur_settings.rng_seed)
301301

302-
problems_for_gradient_check_jax = list(
303-
set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"}
304-
# Laske has nan values in gradient due to nan values in observables that are not used in the likelihood
305-
# but are problematic during backpropagation
306-
)
307-
308302
problem_parameters = None
309-
if problem_id in problems_for_gradient_check_jax:
303+
if problem_id in problems_for_gradient_check:
310304
point = petab_problem.x_nominal_free_scaled
311305
for _ in range(20):
312306
amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
@@ -361,14 +355,14 @@ def test_jax_llh(benchmark_problem):
361355
err_msg=f"LLH mismatch for {problem_id}",
362356
)
363357

364-
if problem_id in problems_for_gradient_check_jax:
358+
if problem_id in problems_for_gradient_check:
365359
sllh_amici = r_amici[SLLH]
366360
np.testing.assert_allclose(
367361
sllh_jax.parameters,
368362
np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]),
369363
rtol=1e-2,
370364
atol=1e-2,
371-
err_msg=f"SLLH mismatch for {problem_id}",
365+
err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}",
372366
)
373367

374368

0 commit comments

Comments
 (0)