Skip to content

Commit 85996a1

Browse files
author
Goose
committed
correct type annotations related to jaxified logp func
1 parent 2855587 commit 85996a1

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

pymc/sampling/jax.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828

2929
from arviz.data.base import make_attrs
3030
from jax.lax import scan
31+
from numpy.typing import ArrayLike
3132
from pytensor.compile import SharedVariable, Supervisor, mode
3233
from pytensor.graph.basic import graph_inputs
3334
from pytensor.graph.fg import FunctionGraph
@@ -121,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
121122
def get_jaxified_graph(
122123
inputs: list[TensorVariable] | None = None,
123124
outputs: list[TensorVariable] | None = None,
124-
) -> list[TensorVariable]:
125+
) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
125126
"""Compile a PyTensor graph into an optimized JAX function."""
126127
graph = _replace_shared_variables(outputs) if outputs is not None else None
127128

@@ -144,15 +145,13 @@ def get_jaxified_graph(
144145
return jax_funcify(fgraph)
145146

146147

147-
def get_jaxified_logp(
148-
model: Model, negative_logp=True
149-
) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
148+
def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
150149
model_logp = model.logp()
151150
if not negative_logp:
152151
model_logp = -model_logp
153152
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
154153

155-
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
154+
def logp_fn_wrap(x: ArrayLike) -> jax.Array:
156155
return logp_fn(*x)[0]
157156

158157
return logp_fn_wrap
@@ -213,7 +212,7 @@ def _get_batched_jittered_initial_points(
213212
chains: int,
214213
initvals: StartDict | Sequence[StartDict | None] | None,
215214
random_seed: RandomSeed,
216-
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
215+
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
217216
jitter: bool = True,
218217
jitter_max_retries: int = 10,
219218
) -> np.ndarray | list[np.ndarray]:
@@ -235,7 +234,7 @@ def _get_batched_jittered_initial_points(
235234

236235
else:
237236

238-
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
237+
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
239238
"""Wrap logp_fn to conform to _init_jitter logic.
240239
241240
Wraps jaxified logp function to accept a dict of

pymc/sampling/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1353,8 +1353,8 @@ def _init_jitter(
13531353
Whether to apply jitter or not.
13541354
jitter_max_retries : int
13551355
Maximum number of repeated attempts at initializing values (per chain).
1356-
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
1357-
Jaxified logp function that takes the output of the initial point functions as input.
1356+
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray | jax.Array] | None
1357+
logp function that takes the output of initial point functions as input.
13581358
If None, will use the results of model.compile_logp().
13591359
13601360
Returns

0 commit comments

Comments (0)