Skip to content

Commit 8ccf28f

Browse files
committed
Fix wrong type hint Node should be Apply
1 parent 6710e28 commit 8ccf28f

File tree

6 files changed

+17
-18
lines changed

6 files changed

+17
-18
lines changed

pymc/distributions/timeseries.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytensor
2222
import pytensor.tensor as pt
2323

24-
from pytensor.graph.basic import Node, ancestors
24+
from pytensor.graph.basic import Apply, ancestors
2525
from pytensor.graph.replace import clone_replace
2626
from pytensor.tensor import TensorVariable
2727
from pytensor.tensor.random.op import RandomVariable
@@ -490,7 +490,7 @@ def step(*args):
490490
constant_term=constant_term,
491491
)(rhos, sigma, init_dist, steps, noise_rng)
492492

493-
def update(self, node: Node):
493+
def update(self, node: Apply):
494494
"""Return the update mapping for the noise RV."""
495495
return {node.inputs[-1]: node.outputs[0]}
496496

@@ -767,7 +767,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
767767
outputs=[noise_next_rng, garch11],
768768
)(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
769769

770-
def update(self, node: Node):
770+
def update(self, node: Apply):
771771
"""Return the update mapping for the noise RV."""
772772
return {node.inputs[-1]: node.outputs[0]}
773773

@@ -918,7 +918,7 @@ def step(*prev_args):
918918
extended_signature=f"(),(s),{','.join('()' for _ in sde_pars)},[rng]->[rng],(t)",
919919
)(init_dist, steps, *sde_pars, noise_rng)
920920

921-
def update(self, node: Node):
921+
def update(self, node: Apply):
922922
"""Return the update mapping for the noise RV."""
923923
return {node.inputs[-1]: node.outputs[0]}
924924

pymc/distributions/truncated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from pytensor import config, graph_replace, scan
2121
from pytensor.graph import Op
22-
from pytensor.graph.basic import Node
22+
from pytensor.graph.basic import Apply
2323
from pytensor.raise_op import CheckAndRaise
2424
from pytensor.scan import until
2525
from pytensor.tensor import TensorConstant, TensorVariable
@@ -211,7 +211,7 @@ def _create_logcdf_exprs(
211211
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
212212
return lower_logcdf, upper_logcdf
213213

214-
def update(self, node: Node):
214+
def update(self, node: Apply):
215215
"""Return the update mapping for the internal RNGs.
216216
217217
TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.

pymc/logprob/binary.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import cast
1415

1516
import numpy as np
1617
import pytensor.tensor as pt
1718

18-
from pytensor.graph.basic import Node
19+
from pytensor.graph.basic import Apply
1920
from pytensor.graph.fg import FunctionGraph
2021
from pytensor.graph.rewriting.basic import node_rewriter
2122
from pytensor.scalar.basic import GE, GT, LE, LT, Invert
@@ -39,7 +40,7 @@ class MeasurableComparison(MeasurableElemwise):
3940

4041

4142
@node_rewriter(tracks=[gt, lt, ge, le])
42-
def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
43+
def find_measurable_comparisons(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
4344
measurable_inputs = filter_measurable_variables(node.inputs)
4445

4546
if len(measurable_inputs) != 1:
@@ -55,7 +56,7 @@ def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[Tenso
5556

5657
# Check that the other input is not potentially measurable, in which case this rewrite
5758
# would be invalid
58-
const = node.inputs[(measurable_var_idx + 1) % 2]
59+
const = cast(TensorVariable, node.inputs[(measurable_var_idx + 1) % 2])
5960

6061
# check for potential measurability of const
6162
if check_potential_measurability([const]):
@@ -127,8 +128,8 @@ class MeasurableBitwise(MeasurableElemwise):
127128

128129

129130
@node_rewriter(tracks=[invert])
130-
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
131-
base_var = node.inputs[0]
131+
def find_measurable_bitwise(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
132+
base_var = cast(TensorVariable, node.inputs[0])
132133

133134
if not base_var.dtype.startswith("bool"):
134135
return None

pymc/logprob/censoring.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import numpy as np
3939
import pytensor.tensor as pt
4040

41-
from pytensor.graph.basic import Node
41+
from pytensor.graph.basic import Apply
4242
from pytensor.graph.fg import FunctionGraph
4343
from pytensor.graph.rewriting.basic import node_rewriter
4444
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
@@ -62,7 +62,7 @@ class MeasurableClip(MeasurableElemwise):
6262

6363

6464
@node_rewriter(tracks=[clip])
65-
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
65+
def find_measurable_clips(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
6666
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
6767

6868
if not filter_measurable_variables(node.inputs):
@@ -153,7 +153,7 @@ class MeasurableRound(MeasurableElemwise):
153153

154154

155155
@node_rewriter(tracks=[ceil, floor, round_half_to_even])
156-
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
156+
def find_measurable_roundings(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
157157
if not filter_measurable_variables(node.inputs):
158158
return None
159159

pymc/logprob/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from pytensor import scan
4444
from pytensor.gradient import jacobian
45-
from pytensor.graph.basic import Node, Variable
45+
from pytensor.graph.basic import Apply, Variable
4646
from pytensor.graph.fg import FunctionGraph
4747
from pytensor.graph.rewriting.basic import node_rewriter
4848
from pytensor.scalar import (
@@ -453,7 +453,7 @@ def measurable_power_exponent_to_exp(fgraph, node):
453453
erfcx,
454454
]
455455
)
456-
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] | None:
456+
def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Variable] | None:
457457
"""Find measurable transformations from Elemwise operators."""
458458
# Node was already converted
459459
if isinstance(node.op, MeasurableOp):

scripts/run_mypy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
pymc/distributions/timeseries.py
3333
pymc/distributions/truncated.py
3434
pymc/initial_point.py
35-
pymc/logprob/binary.py
36-
pymc/logprob/censoring.py
3735
pymc/logprob/basic.py
3836
pymc/logprob/mixture.py
3937
pymc/logprob/rewriting.py

0 commit comments

Comments
 (0)