Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,11 @@ def create_partial_observed_rv(
if can_rewrite:
masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False, features=[ShapeFeature()])
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
unobserved_rv = local_subtensor_rv_lift.transform(fgraph, masked_rv.owner)[masked_rv]

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False, features=[ShapeFeature()])
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
observed_rv = local_subtensor_rv_lift.transform(fgraph, antimasked_rv.owner)[antimasked_rv]

# Make a clone of the observedRV, with a distinct rng so that observed and
# unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
Expand Down
32 changes: 18 additions & 14 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
if not all(params.type.broadcastable):
return None

# Check whether axis covers all dimensions
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None
if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
# Check whether axis covers all dimensions
axis = tuple(sorted(node.op.axis))
if axis != tuple(range(base_var.ndim)):
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_max: Max
if base_var.type.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
measurable_max = MeasurableMaxDiscrete(axis)
else:
measurable_max = MeasurableMax(list(axis))
measurable_max = MeasurableMax(axis)

max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs
Expand Down Expand Up @@ -206,21 +208,23 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVa
if not all(params.type.broadcastable):
return None

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None
if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
# Check whether axis is supported or not
axis = tuple(sorted(node.op.axis))
if axis != tuple(range(base_var.ndim)):
return None

if not rv_map_feature.request_measurable([base_rv]):
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_min: Max
if base_rv.type.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
measurable_min = MeasurableDiscreteMaxNeg(axis)
else:
measurable_min = MeasurableMaxNeg(list(axis))
measurable_min = MeasurableMaxNeg(axis)

return measurable_min.make_node(base_rv).outputs

Expand Down
7 changes: 0 additions & 7 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -374,12 +373,6 @@ def incsubtensor_rv_replace(fgraph, node):

logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")

# Split max_and_argmax
# We only register this in the measurable IR db because max does not have a grad implemented
# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")

# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
Expand Down
4 changes: 3 additions & 1 deletion pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
# SOFTWARE.


from pathlib import Path

import pytensor

from pytensor import tensor as pt
Expand Down Expand Up @@ -237,7 +239,7 @@ class MeasurableDimShuffle(DimShuffle):

# Need to get the absolute path of `c_func_file`, otherwise it tries to
# find it locally and fails when a new `Op` is initialized
c_func_file = DimShuffle.get_path(DimShuffle.c_func_file)
c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file)))


MeasurableVariable.register(MeasurableDimShuffle)
Expand Down
4 changes: 2 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
graph_inputs,
walk,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
Expand Down Expand Up @@ -897,7 +897,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
[client, _] = rng_clients[0]

# RNG is an output of the function, this is not a problem
if client == "output":
if isinstance(client.op, Output):
return rng

# RNG is used by another operator, which should output an update for the RNG
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.23,<2.24
pytensor>=2.25.1,<2.26
pytest-cov>=2.5
pytest>=3.0
rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cachetools>=4.2.1
cloudpickle
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.23,<2.24
pytensor>=2.25.1,<2.26
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def test_skewstudentt_logcdf(self):
check_logcdf(
pm.SkewStudentT,
R,
{"a": Rplus, "b": Rplus, "mu": R, "sigma": Rplus},
{"a": Rplus, "b": Rplus, "mu": R, "sigma": Rplusbig},
lambda value, a, b, mu, sigma: st.jf_skew_t.logcdf(value, a, b, mu, sigma),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def test_batched_size(self, explicit_shape, batched_param):
with Model() as t0:
y = GARCH11("y", **kwargs0)

y_eval = draw(y, draws=2)
y_eval = draw(y, draws=2, random_seed=800)
assert y_eval[0].shape == (batch_size, steps)
assert not np.any(np.isclose(y_eval[0], y_eval[1]))

Expand Down