Skip to content

Commit cff8c6c

Browse files
Delay jax import
1 parent bfd8365 commit cff8c6c

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pymc_extras/inference/find_map.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Callable
44
from typing import Literal, cast, get_args
55

6-
import jax
76
import numpy as np
87
import pymc as pm
98
import pytensor
@@ -138,6 +137,13 @@ def _compile_grad_and_hess_to_jax(
138137
f_hessp: Callable | None
139138
The compiled hessian-vector product function, or None if use_hessp is False.
140139
"""
140+
try:
141+
import jax
142+
except ImportError:
143+
raise ImportError(
144+
"You don't have jax installed -- it is required if gradient_backend='jax'"
145+
)
146+
141147
f_hess = None
142148
f_hessp = None
143149

0 commit comments

Comments
 (0)