File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed
Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import arviz as az
1514
1615
17- def fit (method : str , ** kwargs ) -> az . InferenceData :
16+ def fit (method , ** kwargs ):
1817 """
1918 Fit a model with an inference algorithm
2019
Original file line number Diff line number Diff line change 2121from collections .abc import Callable , Iterator
2222from dataclasses import asdict , dataclass , field , replace
2323from enum import Enum , auto
24+ from importlib .util import find_spec
2425from typing import Literal , TypeAlias
2526
2627import arviz as az
28+ import blackjax
2729import filelock
2830import jax
2931import numpy as np
@@ -1734,8 +1736,8 @@ def fit_pathfinder(
17341736 )
17351737 pathfinder_samples = mp_result .samples
17361738 elif inference_backend == "blackjax" :
1737- import blackjax
1738-
1739+ if find_spec ( " blackjax" ) is None :
1740+ raise RuntimeError ( "Need BlackJAX to use `pathfinder`" )
17391741 if version .parse (blackjax .__version__ ).major < 1 :
17401742 raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
17411743
You can’t perform that action at this time.
0 commit comments