We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bfd8365 commit cff8c6cCopy full SHA for cff8c6c
pymc_extras/inference/find_map.py
@@ -3,7 +3,6 @@
3
from collections.abc import Callable
4
from typing import Literal, cast, get_args
5
6
-import jax
7
import numpy as np
8
import pymc as pm
9
import pytensor
@@ -138,6 +137,13 @@ def _compile_grad_and_hess_to_jax(
138
137
f_hessp: Callable | None
139
The compiled hessian-vector product function, or None if use_hessp is False.
140
"""
+ try:
141
+ import jax
142
+ except ImportError:
143
+ raise ImportError(
144
+ "You don't have jax installed -- it is required if gradient_backend='jax'"
145
+ )
146
+
147
f_hess = None
148
f_hessp = None
149
0 commit comments