Skip to content

Commit 4c9c867

Browse files
authored
Revert "Minor fix of blackjax import in fit_pathfinder function (pymc-devs#443)"
This reverts commit 1fff560.
1 parent 1fff560 commit 4c9c867

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pymc_extras/inference/fit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import arviz as az
1514

1615

17-
def fit(method: str, **kwargs) -> az.InferenceData:
16+
def fit(method, **kwargs):
1817
"""
1918
Fit a model with an inference algorithm
2019

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
from collections.abc import Callable, Iterator
2222
from dataclasses import asdict, dataclass, field, replace
2323
from enum import Enum, auto
24+
from importlib.util import find_spec
2425
from typing import Literal, TypeAlias
2526

2627
import arviz as az
28+
import blackjax
2729
import filelock
2830
import jax
2931
import numpy as np
@@ -1734,8 +1736,8 @@ def fit_pathfinder(
17341736
)
17351737
pathfinder_samples = mp_result.samples
17361738
elif inference_backend == "blackjax":
1737-
import blackjax
1738-
1739+
if find_spec("blackjax") is None:
1740+
raise RuntimeError("Need BlackJAX to use `pathfinder`")
17391741
if version.parse(blackjax.__version__).major < 1:
17401742
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
17411743

0 commit comments

Comments
 (0)