Skip to content

Commit 5c3b295

Browse files
committed
commit
1 parent 6be8577 commit 5c3b295

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc_experimental/inference/pathfinder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,18 @@ def fit_pathfinder(
9595
"""
9696
# Temporarily helper
9797
if version.parse(blackjax.__version__).major < 1:
98+
# test
9899
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
99-
100100
model = modelcontext(model)
101101

102-
ip = model.initial_point()
102+
ip = model.initial_point()
103103
ip_map = DictToArrayBijection.map(ip)
104104

105105
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
106-
ip, (model.logp(),), model.value_vars, ()
106+
ip,
107+
(model.logp(),),
108+
model.value_vars,
109+
(),
107110
)
108111

109112
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)

0 commit comments

Comments
 (0)