Skip to content

Commit 7c07f5d

Browse files
committed
More general type-hints
1 parent b935d0d commit 7c07f5d

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

pymc/distributions/shape_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,17 +230,17 @@ def rv_size_is_none(size: TensorVariable | Constant | None) -> bool:
230230

231231

232232
@singledispatch
233-
def _change_dist_size(op: Op, dist: TensorVariable, new_size, expand):
233+
def _change_dist_size(op: Op, dist: Variable, new_size, expand):
234234
raise NotImplementedError(
235235
f"Variable {dist} of type {op} has no _change_dist_size implementation."
236236
)
237237

238238

239239
def change_dist_size(
240-
dist: TensorVariable,
240+
dist: Variable,
241241
new_size: PotentialShapeType,
242242
expand: bool = False,
243-
) -> TensorVariable:
243+
) -> Variable:
244244
"""Change or expand the size of a Distribution.
245245
246246
Parameters

pymc/sampling/forward.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from pytensor.graph.traversal import ancestors, general_toposort, walk
4343
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
4444
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
45-
from pytensor.tensor.variable import TensorConstant, TensorVariable
45+
from pytensor.tensor.variable import TensorConstant
4646
from rich.console import Console
4747
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4848
from rich.theme import Theme
@@ -1032,10 +1032,8 @@ def vectorize_over_posterior(
10321032
If random variables are found in the graph and `allow_rvs_in_graph` is False
10331033
"""
10341034
# Identify which free RVs are needed to compute `outputs`
1035-
needed_rvs: list[TensorVariable] = [
1036-
cast(TensorVariable, rv)
1037-
for rv in ancestors(outputs, blockers=input_rvs)
1038-
if rv in set(input_rvs)
1035+
needed_rvs: list[Variable] = [
1036+
rv for rv in ancestors(outputs, blockers=input_rvs) if rv in set(input_rvs)
10391037
]
10401038

10411039
# Replace needed_rvs with actual posterior samples
@@ -1044,7 +1042,7 @@ def vectorize_over_posterior(
10441042
for rv in needed_rvs:
10451043
posterior_samples = posterior[rv.name].data
10461044

1047-
replace_dict[rv] = pt.constant(posterior_samples.astype(rv.dtype), name=rv.name)
1045+
replace_dict[rv] = pt.constant(posterior_samples.astype(rv.dtype), name=rv.name) # type: ignore[attr-defined]
10481046

10491047
# Replace the rvs that remain in the graph with resized versions
10501048
all_rvs = rvs_in_graph(outputs)

0 commit comments

Comments
 (0)