2828
2929from arviz .data .base import make_attrs
3030from jax .lax import scan
31+ from numpy .typing import ArrayLike
3132from pytensor .compile import SharedVariable , Supervisor , mode
3233from pytensor .graph .basic import graph_inputs
3334from pytensor .graph .fg import FunctionGraph
@@ -121,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
121122def get_jaxified_graph (
122123 inputs : list [TensorVariable ] | None = None ,
123124 outputs : list [TensorVariable ] | None = None ,
124- ) -> list [TensorVariable ]:
125+ ) -> Callable [[ list [TensorVariable ]], list [ TensorVariable ] ]:
125126 """Compile a PyTensor graph into an optimized JAX function."""
126127 graph = _replace_shared_variables (outputs ) if outputs is not None else None
127128
@@ -144,15 +145,13 @@ def get_jaxified_graph(
144145 return jax_funcify (fgraph )
145146
146147
147- def get_jaxified_logp (
148- model : Model , negative_logp = True
149- ) -> Callable [[Sequence [np .ndarray ]], np .ndarray ]:
148+ def get_jaxified_logp (model : Model , negative_logp : bool = True ) -> Callable [[ArrayLike ], jax .Array ]:
150149 model_logp = model .logp ()
151150 if not negative_logp :
152151 model_logp = - model_logp
153152 logp_fn = get_jaxified_graph (inputs = model .value_vars , outputs = [model_logp ])
154153
155- def logp_fn_wrap (x : Sequence [ np . ndarray ] ) -> np . ndarray :
154+ def logp_fn_wrap (x : ArrayLike ) -> jax . Array :
156155 return logp_fn (* x )[0 ]
157156
158157 return logp_fn_wrap
@@ -213,7 +212,7 @@ def _get_batched_jittered_initial_points(
213212 chains : int ,
214213 initvals : StartDict | Sequence [StartDict | None ] | None ,
215214 random_seed : RandomSeed ,
216- logp_fn : Callable [[Sequence [ np . ndarray ]], np . ndarray ] | None = None ,
215+ logp_fn : Callable [[ArrayLike ], jax . Array ] | None = None ,
217216 jitter : bool = True ,
218217 jitter_max_retries : int = 10 ,
219218) -> np .ndarray | list [np .ndarray ]:
@@ -235,7 +234,7 @@ def _get_batched_jittered_initial_points(
235234
236235 else :
237236
238- def eval_logp_initial_point (point : dict [str , np .ndarray ]) -> np . ndarray :
237+ def eval_logp_initial_point (point : dict [str , np .ndarray ]) -> jax . Array :
239238 """Wrap logp_fn to conform to _init_jitter logic.
240239
241240 Wraps jaxified logp function to accept a dict of
0 commit comments