diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 6469cd101..14fcd1ec8 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -21,7 +21,7 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node, ancestors
+from pytensor.graph.basic import Apply, ancestors
 from pytensor.graph.replace import clone_replace
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
@@ -490,7 +490,7 @@ def step(*args):
             constant_term=constant_term,
         )(rhos, sigma, init_dist, steps, noise_rng)
 
-    def update(self, node: Node):
+    def update(self, node: Apply):
         """Return the update mapping for the noise RV."""
         return {node.inputs[-1]: node.outputs[0]}
 
@@ -767,7 +767,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
             outputs=[noise_next_rng, garch11],
         )(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
 
-    def update(self, node: Node):
+    def update(self, node: Apply):
         """Return the update mapping for the noise RV."""
         return {node.inputs[-1]: node.outputs[0]}
 
@@ -918,7 +918,7 @@ def step(*prev_args):
             extended_signature=f"(),(s),{','.join('()' for _ in sde_pars)},[rng]->[rng],(t)",
         )(init_dist, steps, *sde_pars, noise_rng)
 
-    def update(self, node: Node):
+    def update(self, node: Apply):
         """Return the update mapping for the noise RV."""
         return {node.inputs[-1]: node.outputs[0]}
 
diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py
index 6f32918bb..36b439526 100644
--- a/pymc/distributions/truncated.py
+++ b/pymc/distributions/truncated.py
@@ -19,7 +19,7 @@
 
 from pytensor import config, graph_replace, scan
 from pytensor.graph import Op
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Apply
 from pytensor.raise_op import CheckAndRaise
 from pytensor.scan import until
 from pytensor.tensor import TensorConstant, TensorVariable
@@ -211,7 +211,7 @@ def _create_logcdf_exprs(
         upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
         return lower_logcdf, upper_logcdf
 
-    def update(self, node: Node):
+    def update(self, node: Apply):
         """Return the update mapping for the internal RNGs.
 
         TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py
index 0767d25f8..9d0985a2c 100644
--- a/pymc/logprob/binary.py
+++ b/pymc/logprob/binary.py
@@ -11,11 +11,12 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+from typing import cast
 
 import numpy as np
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar.basic import GE, GT, LE, LT, Invert
@@ -39,7 +40,7 @@ class MeasurableComparison(MeasurableElemwise):
 
 
 @node_rewriter(tracks=[gt, lt, ge, le])
-def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_comparisons(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     measurable_inputs = filter_measurable_variables(node.inputs)
 
     if len(measurable_inputs) != 1:
@@ -55,7 +56,7 @@ def find_measurable_comparisons(fgraph: FunctionGraph, node: Apply) -> list[Tenso
 
     # Check that the other input is not potentially measurable, in which case this rewrite
     # would be invalid
-    const = node.inputs[(measurable_var_idx + 1) % 2]
+    const = cast(TensorVariable, node.inputs[(measurable_var_idx + 1) % 2])
 
     # check for potential measurability of const
     if check_potential_measurability([const]):
@@ -127,11 +128,11 @@ class MeasurableBitwise(MeasurableElemwise):
 
 
 @node_rewriter(tracks=[invert])
-def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
-    base_var = node.inputs[0]
+def find_measurable_bitwise(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
+    base_var = cast(TensorVariable, node.inputs[0])
 
     if not base_var.dtype.startswith("bool"):
-        raise None
+        return None
 
     if not filter_measurable_variables([base_var]):
         return None
diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py
index 2104ecb6e..e17d30a43 100644
--- a/pymc/logprob/censoring.py
+++ b/pymc/logprob/censoring.py
@@ -38,7 +38,7 @@
 import numpy as np
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
@@ -62,7 +62,7 @@ class MeasurableClip(MeasurableElemwise):
 
 
 @node_rewriter(tracks=[clip])
-def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_clips(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
 
     if not filter_measurable_variables(node.inputs):
@@ -153,7 +153,7 @@ class MeasurableRound(MeasurableElemwise):
 
 
 @node_rewriter(tracks=[ceil, floor, round_half_to_even])
-def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+def find_measurable_roundings(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
     if not filter_measurable_variables(node.inputs):
         return None
 
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 41233223b..930bf1f4e 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -42,7 +42,7 @@
 
 from pytensor import scan
 from pytensor.gradient import jacobian
-from pytensor.graph.basic import Node, Variable
+from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar import (
@@ -453,7 +453,7 @@ def measurable_power_exponent_to_exp(fgraph, node):
         erfcx,
     ]
 )
-def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] | None:
+def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Variable] | None:
     """Find measurable transformations from Elemwise operators."""
     # Node was already converted
     if isinstance(node.op, MeasurableOp):
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index d825c8857..a7c7a36c6 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -59,7 +59,7 @@
     "MultivariateNormalProposal",
 ]
 
-from pymc.util import get_value_vars_from_user_vars
+from pymc.util import RandomGenerator, get_value_vars_from_user_vars
 
 
 # Available proposal distributions for Metropolis
@@ -302,7 +302,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         accept_rate = self.delta_logp(q, q0d)
         q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng)
         self.accept_rate_iter = accept_rate
-        self.accepted_iter = accepted
+        self.accepted_iter[0] = accepted
         self.accepted_sum += accepted
 
         self.steps_until_tune -= 1
@@ -622,14 +622,16 @@ class CategoricalGibbsMetropolis(ArrayStep):
 
     _state_class = CategoricalGibbsMetropolisState
 
-    def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
+    def __init__(
+        self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None
+    ):
         model = pm.modelcontext(model)
         vars = get_value_vars_from_user_vars(vars, model)
         initial_point = model.initial_point()
 
-        dimcats = []
+        dimcats: list[tuple[int, int]] = []
         # The above variable is a list of pairs (aggregate dimension, number
         # of categories). For example, if vars = [x, y] with x being a 2-D
         # variable with M categories and y being a 3-D variable with N
@@ -665,10 +667,10 @@ def __init__(self, vars, proposal="uniform", order="random", model=None, rng=Non
             self.dimcats = [dimcats[j] for j in order]
 
         if proposal == "uniform":
-            self.astep = self.astep_unif
+            self.astep = self.astep_unif  # type: ignore[assignment]
         elif proposal == "proportional":
             # Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
-            self.astep = self.astep_prop
+            self.astep = self.astep_prop  # type: ignore[assignment]
         else:
             raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
 
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 842fb0a13..ccce06736 100755
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -32,8 +32,6 @@
 pymc/distributions/timeseries.py
 pymc/distributions/truncated.py
 pymc/initial_point.py
-pymc/logprob/binary.py
-pymc/logprob/censoring.py
 pymc/logprob/basic.py
 pymc/logprob/mixture.py
 pymc/logprob/rewriting.py