Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytensor
import pytensor.tensor as pt

from pytensor.graph.basic import Node, ancestors
from pytensor.graph.basic import Apply, ancestors
from pytensor.graph.replace import clone_replace
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -490,7 +490,7 @@ def step(*args):
constant_term=constant_term,
)(rhos, sigma, init_dist, steps, noise_rng)

def update(self, node: Node):
def update(self, node: Apply):
"""Return the update mapping for the noise RV."""
return {node.inputs[-1]: node.outputs[0]}

Expand Down Expand Up @@ -767,7 +767,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
outputs=[noise_next_rng, garch11],
)(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)

def update(self, node: Node):
def update(self, node: Apply):
"""Return the update mapping for the noise RV."""
return {node.inputs[-1]: node.outputs[0]}

Expand Down Expand Up @@ -918,7 +918,7 @@ def step(*prev_args):
extended_signature=f"(),(s),{','.join('()' for _ in sde_pars)},[rng]->[rng],(t)",
)(init_dist, steps, *sde_pars, noise_rng)

def update(self, node: Node):
def update(self, node: Apply):
"""Return the update mapping for the noise RV."""
return {node.inputs[-1]: node.outputs[0]}

Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pytensor import config, graph_replace, scan
from pytensor.graph import Op
from pytensor.graph.basic import Node
from pytensor.graph.basic import Apply
from pytensor.raise_op import CheckAndRaise
from pytensor.scan import until
from pytensor.tensor import TensorConstant, TensorVariable
Expand Down Expand Up @@ -211,7 +211,7 @@ def _create_logcdf_exprs(
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf

def update(self, node: Node):
def update(self, node: Apply):
"""Return the update mapping for the internal RNGs.

TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
Expand Down
13 changes: 7 additions & 6 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GE, GT, LE, LT, Invert
Expand All @@ -39,7 +40,7 @@ class MeasurableComparison(MeasurableElemwise):


@node_rewriter(tracks=[gt, lt, ge, le])
def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
def find_measurable_comparisons(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
measurable_inputs = filter_measurable_variables(node.inputs)

if len(measurable_inputs) != 1:
Expand All @@ -55,7 +56,7 @@ def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[Tenso

# Check that the other input is not potentially measurable, in which case this rewrite
# would be invalid
const = node.inputs[(measurable_var_idx + 1) % 2]
const = cast(TensorVariable, node.inputs[(measurable_var_idx + 1) % 2])

# check for potential measurability of const
if check_potential_measurability([const]):
Expand Down Expand Up @@ -127,11 +128,11 @@ class MeasurableBitwise(MeasurableElemwise):


@node_rewriter(tracks=[invert])
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
base_var = node.inputs[0]
def find_measurable_bitwise(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
base_var = cast(TensorVariable, node.inputs[0])

if not base_var.dtype.startswith("bool"):
raise None
return None

if not filter_measurable_variables([base_var]):
return None
Expand Down
6 changes: 3 additions & 3 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
Expand All @@ -62,7 +62,7 @@ class MeasurableClip(MeasurableElemwise):


@node_rewriter(tracks=[clip])
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
def find_measurable_clips(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

if not filter_measurable_variables(node.inputs):
Expand Down Expand Up @@ -153,7 +153,7 @@ class MeasurableRound(MeasurableElemwise):


@node_rewriter(tracks=[ceil, floor, round_half_to_even])
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
def find_measurable_roundings(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
if not filter_measurable_variables(node.inputs):
return None

Expand Down
4 changes: 2 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from pytensor import scan
from pytensor.gradient import jacobian
from pytensor.graph.basic import Node, Variable
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar import (
Expand Down Expand Up @@ -453,7 +453,7 @@ def measurable_power_exponent_to_exp(fgraph, node):
erfcx,
]
)
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] | None:
def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Variable] | None:
"""Find measurable transformations from Elemwise operators."""
# Node was already converted
if isinstance(node.op, MeasurableOp):
Expand Down
14 changes: 8 additions & 6 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"MultivariateNormalProposal",
]

from pymc.util import get_value_vars_from_user_vars
from pymc.util import RandomGenerator, get_value_vars_from_user_vars

# Available proposal distributions for Metropolis

Expand Down Expand Up @@ -302,7 +302,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
accept_rate = self.delta_logp(q, q0d)
q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng)
self.accept_rate_iter = accept_rate
self.accepted_iter = accepted
self.accepted_iter[0] = accepted
self.accepted_sum += accepted

self.steps_until_tune -= 1
Expand Down Expand Up @@ -622,14 +622,16 @@ class CategoricalGibbsMetropolis(ArrayStep):

_state_class = CategoricalGibbsMetropolisState

def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
def __init__(
self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None
):
model = pm.modelcontext(model)

vars = get_value_vars_from_user_vars(vars, model)

initial_point = model.initial_point()

dimcats = []
dimcats: list[tuple[int, int]] = []
# The above variable is a list of pairs (aggregate dimension, number
# of categories). For example, if vars = [x, y] with x being a 2-D
# variable with M categories and y being a 3-D variable with N
Expand Down Expand Up @@ -665,10 +667,10 @@ def __init__(self, vars, proposal="uniform", order="random", model=None, rng=Non
self.dimcats = [dimcats[j] for j in order]

if proposal == "uniform":
self.astep = self.astep_unif
self.astep = self.astep_unif # type: ignore[assignment]
elif proposal == "proportional":
# Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
self.astep = self.astep_prop
self.astep = self.astep_prop # type: ignore[assignment]
else:
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")

Expand Down
2 changes: 0 additions & 2 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
pymc/distributions/timeseries.py
pymc/distributions/truncated.py
pymc/initial_point.py
pymc/logprob/binary.py
pymc/logprob/censoring.py
pymc/logprob/basic.py
pymc/logprob/mixture.py
pymc/logprob/rewriting.py
Expand Down