File tree Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change 11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import arviz as az
14
15
15
16
16
- def fit (method , ** kwargs ):
17
+ def fit (method : str , ** kwargs ) -> az . InferenceData :
17
18
"""
18
19
Fit a model with an inference algorithm
19
20
Original file line number Diff line number Diff line change 21
21
from collections .abc import Callable , Iterator
22
22
from dataclasses import asdict , dataclass , field , replace
23
23
from enum import Enum , auto
24
- from importlib .util import find_spec
25
24
from typing import Literal , TypeAlias
26
25
27
26
import arviz as az
28
- import blackjax
29
27
import filelock
30
28
import jax
31
29
import numpy as np
@@ -1736,8 +1734,8 @@ def fit_pathfinder(
1736
1734
)
1737
1735
pathfinder_samples = mp_result .samples
1738
1736
elif inference_backend == "blackjax" :
1739
- if find_spec ( " blackjax" ) is None :
1740
- raise RuntimeError ( "Need BlackJAX to use `pathfinder`" )
1737
+ import blackjax
1738
+
1741
1739
if version .parse (blackjax .__version__ ).major < 1 :
1742
1740
raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
1743
1741
You can’t perform that action at this time.
0 commit comments