Skip to content

Commit 516cf27

Browse files
michaelosthegetwiecki
authored andcommitted
Fix typing in logprob.tensor
1 parent cbf1591 commit 516cf27

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

pymc/logprob/rewriting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import warnings
3737

3838
from collections import deque
39-
from collections.abc import Sequence
39+
from collections.abc import Collection, Sequence
4040

4141
import pytensor.tensor as pt
4242

@@ -484,7 +484,7 @@ def cleanup_ir(vars: Sequence[Variable]) -> None:
484484

485485

486486
def assume_measured_ir_outputs(
487-
inputs: Sequence[TensorVariable], outputs: Sequence[TensorVariable]
487+
inputs: Collection[TensorVariable], outputs: Sequence[TensorVariable]
488488
) -> Sequence[TensorVariable]:
489489
"""Run IR rewrite assuming each output is measured.
490490

pymc/logprob/tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import pytensor
3939

4040
from pytensor import tensor as pt
41+
from pytensor.graph.fg import FunctionGraph
4142
from pytensor.graph.op import compute_test_value
4243
from pytensor.graph.rewriting.basic import node_rewriter
4344
from pytensor.tensor import TensorVariable
@@ -60,7 +61,7 @@
6061

6162

6263
@node_rewriter([Alloc])
63-
def naive_bcast_rv_lift(fgraph, node):
64+
def naive_bcast_rv_lift(fgraph: FunctionGraph, node):
6465
"""Lift an ``Alloc`` through a ``RandomVariable`` ``Op``.
6566
6667
XXX: This implementation simply broadcasts the ``RandomVariable``'s
@@ -226,6 +227,7 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
226227
measurable_stack = MeasurableJoin()(axis, *base_vars)
227228
else:
228229
measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
230+
assert isinstance(measurable_stack, TensorVariable)
229231

230232
return [measurable_stack]
231233

@@ -242,7 +244,7 @@ class MeasurableDimShuffle(DimShuffle):
242244

243245

244246
@_logprob.register(MeasurableDimShuffle)
245-
def logprob_dimshuffle(op, values, base_var, **kwargs):
247+
def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
246248
"""Compute the log-likelihood graph for a `MeasurableDimShuffle`."""
247249
(value,) = values
248250

@@ -300,6 +302,7 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
300302
measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
301303
base_var
302304
)
305+
assert isinstance(measurable_dimshuffle, TensorVariable)
303306

304307
return [measurable_dimshuffle]
305308

pymc/logprob/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import typing
3737
import warnings
3838

39-
from collections.abc import Container, Sequence
39+
from collections.abc import Container, Iterable, Sequence
4040

4141
import numpy as np
4242
import pytensor
@@ -173,7 +173,7 @@ def indices_from_subtensor(idx_list, indices):
173173

174174

175175
def check_potential_measurability(
176-
inputs: tuple[TensorVariable], valued_rvs: Container[TensorVariable]
176+
inputs: Iterable[TensorVariable], valued_rvs: Container[TensorVariable]
177177
) -> bool:
178178
valued_rvs = set(valued_rvs)
179179

scripts/run_mypy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
pymc/logprob/mixture.py
3737
pymc/logprob/rewriting.py
3838
pymc/logprob/scan.py
39-
pymc/logprob/tensor.py
4039
pymc/logprob/transform_value.py
4140
pymc/logprob/transforms.py
4241
pymc/logprob/utils.py

0 commit comments

Comments
 (0)