File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change 11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import arviz as az
15
14
16
15
17
- def fit (method : str , ** kwargs ) -> az . InferenceData :
16
+ def fit (method , ** kwargs ):
18
17
"""
19
18
Fit a model with an inference algorithm
20
19
Original file line number Diff line number Diff line change 21
21
from collections .abc import Callable , Iterator
22
22
from dataclasses import asdict , dataclass , field , replace
23
23
from enum import Enum , auto
24
+ from importlib .util import find_spec
24
25
from typing import Literal , TypeAlias
25
26
26
27
import arviz as az
28
+ import blackjax
27
29
import filelock
28
30
import jax
29
31
import numpy as np
@@ -1734,8 +1736,8 @@ def fit_pathfinder(
1734
1736
)
1735
1737
pathfinder_samples = mp_result .samples
1736
1738
elif inference_backend == "blackjax" :
1737
- import blackjax
1738
-
1739
+ if find_spec ( " blackjax" ) is None :
1740
+ raise RuntimeError ( "Need BlackJAX to use `pathfinder`" )
1739
1741
if version .parse (blackjax .__version__ ).major < 1 :
1740
1742
raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
1741
1743
You can’t perform that action at this time.
0 commit comments