Commit 63da6d1

Fix types in tensor/random/utils.py and tensor/utils.py

1 parent 6a295b9

3 files changed: 26 additions & 20 deletions

pytensor/link/numba/dispatch/random.py

Lines changed: 1 addition & 0 deletions

@@ -250,6 +250,7 @@ def random_fn(rng, mu, kappa):
 
 @numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement)
 def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node):
+    assert isinstance(op.signature, str)
     [core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1]
     core_shape_len = int(core_shape_len_sig)
     implicit_arange = op.ndims_params[0] == 0
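Note on the added assert: it is there for the type checker as much as for runtime. `op.signature` is presumably declared broader than `str` on the op (e.g. optional), while `_parse_gufunc_signature` is now annotated to take a `str` (see pytensor/tensor/utils.py below), so `assert isinstance(...)` lets mypy narrow the attribute's type at the call site. A minimal sketch of the pattern, with illustrative names rather than PyTensor's actual classes:

    class Op:
        signature: str | None = None  # declared broader than the call site needs

    def parse(sig: str) -> list[str]:
        return sig.split(",")

    def compile_op(op: Op) -> list[str]:
        # Without the assert, mypy reports: argument has incompatible
        # type "str | None"; expected "str".
        assert isinstance(op.signature, str)
        return parse(op.signature)  # narrowed to "str" from here on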

pytensor/tensor/random/utils.py

Lines changed: 13 additions & 7 deletions

@@ -1,8 +1,8 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import wraps
 from itertools import zip_longest
 from types import ModuleType
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING
 
 import numpy as np
 
@@ -22,7 +22,9 @@
 from pytensor.tensor.random.op import RandomVariable
 
 
-def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
+def params_broadcast_shapes(
+    param_shapes: Sequence, ndims_params: Sequence[int], use_pytensor: bool = True
+) -> list[tuple[int, ...]]:
     """Broadcast parameters that have different dimensions.
 
     Parameters
@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
 
     Returns
     =======
-    bcast_shapes : list of ndarray
+    bcast_shapes : list of tuples of ints
        The broadcasted values of `params`.
     """
     max_fn = maximum if use_pytensor else max
 
-    rev_extra_dims = []
+    rev_extra_dims: list[int] = []
     for ndim_param, param_shape in zip(ndims_params, param_shapes):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
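With `use_pytensor=False`, `params_broadcast_shapes` operates on plain tuples of ints, which is exactly what the corrected return annotation describes: batch dimensions are broadcast while the trailing `ndims_params[i]` core dimensions of each parameter are preserved. A usage sketch; the expected output is inferred from the docstring and worth double-checking:

    from pytensor.tensor.random.utils import params_broadcast_shapes

    shapes = params_broadcast_shapes(
        [(3, 1), (1, 1)],  # a vector parameter (1 core dim) and a matrix one (2)
        ndims_params=[1, 2],
        use_pytensor=False,
    )
    print(shapes)  # expected: [(3, 1), (3, 1, 1)]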
@@ -71,7 +73,9 @@ def max_bcast(x, y):
     return bcast_shapes
 
 
-def broadcast_params(params, ndims_params):
+def broadcast_params(
+    params: Sequence[np.ndarray | TensorVariable], ndims_params: Sequence[int]
+) -> list[np.ndarray]:
     """Broadcast parameters that have different dimensions.
 
     >>> ndims_params = [1, 2]
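For context, the doctest truncated by this hunk broadcasts a (3, 1) mean against a (1, 1) covariance; under the new annotation the result is a plain list of ndarrays. The shapes below are inferred from that example and should be verified against the file:

    import numpy as np
    from pytensor.tensor.random.utils import broadcast_params

    mean = np.array([[0.2], [10], [50]])  # shape (3, 1): batch (3,), core (1,)
    cov = np.diag([0.25])                 # shape (1, 1): batch (),  core (1, 1)
    res = broadcast_params([mean, cov], ndims_params=[1, 2])
    print(res[0].shape, res[1].shape)  # expected: (3, 1) (3, 1, 1)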
@@ -215,7 +219,9 @@ def __init__(
         self,
         seed: int | None = None,
         namespace: ModuleType | None = None,
-        rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
+        rng_ctor: Callable[
+            [np.random.SeedSequence], np.random.Generator
+        ] = np.random.default_rng,
     ):
         if namespace is None:
             from pytensor.tensor.random import basic  # pylint: disable=import-self
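`Literal[np.random.Generator]` was never a legal annotation: `Literal[...]` only accepts literal values, so mypy rejects a class there, and it did not describe the actual default, `np.random.default_rng`, which is a callable. The `Callable[[np.random.SeedSequence], np.random.Generator]` form states what `RandomStream` actually needs. Both of the following should now type-check (the custom constructor is hypothetical, not part of PyTensor):

    import numpy as np
    from pytensor.tensor.random.utils import RandomStream

    srng = RandomStream(seed=42)  # default: np.random.default_rng

    def pcg_ctor(seed_seq: np.random.SeedSequence) -> np.random.Generator:
        # Any callable mapping a SeedSequence to a Generator fits the annotation.
        return np.random.Generator(np.random.PCG64(seed_seq))

    srng_pcg = RandomStream(seed=42, rng_ctor=pcg_ctor)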

pytensor/tensor/utils.py

Lines changed: 12 additions & 13 deletions

@@ -6,10 +6,11 @@
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 import pytensor
+from pytensor.graph import FunctionGraph, Variable
 from pytensor.utils import hash_from_code
 
 
-def hash_from_ndarray(data):
+def hash_from_ndarray(data) -> str:
     """
     Return a hash from an ndarray.
 
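The new `-> str` reflects that the helper ultimately delegates to `hash_from_code`, which produces a digest string. A quick sanity check, assuming that documented behavior:

    import numpy as np
    from pytensor.tensor.utils import hash_from_ndarray

    h = hash_from_ndarray(np.arange(10))
    assert isinstance(h, str)  # a digest string, not bytes or int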
@@ -36,7 +37,9 @@ def hash_from_ndarray(data):
     )
 
 
-def shape_of_variables(fgraph, input_shapes):
+def shape_of_variables(
+    fgraph: FunctionGraph, input_shapes
+) -> dict[Variable, tuple[int, ...]]:
     """
     Compute the numeric shape of all intermediate variables given input shapes.
 
@@ -73,16 +76,14 @@ def shape_of_variables(fgraph, input_shapes):
 
     fgraph.attach_feature(ShapeFeature())
 
+    shape_feature = fgraph.shape_feature  # type: ignore[attr-defined]
+
     input_dims = [
-        dimension
-        for inp in fgraph.inputs
-        for dimension in fgraph.shape_feature.shape_of[inp]
+        dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp]
     ]
 
     output_dims = [
-        dimension
-        for shape in fgraph.shape_feature.shape_of.values()
-        for dimension in shape
+        dimension for shape in shape_feature.shape_of.values() for dimension in shape
     ]
 
     compute_shapes = pytensor.function(input_dims, output_dims)
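The local alias exists because `shape_of` is not a static attribute of `FunctionGraph`: `attach_feature(ShapeFeature())` grafts it on at runtime, so mypy cannot see it. Binding it once behind a single targeted `type: ignore` confines the suppression to one line instead of repeating it at every access. The general pattern, with illustrative names:

    class Host:
        pass

    host = Host()
    setattr(host, "shape_of", {})  # attribute added dynamically at runtime

    shape_of = host.shape_of  # type: ignore[attr-defined]  # one suppression here
    print(len(shape_of))      # ...instead of an ignore at every later access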
@@ -100,10 +101,8 @@ def shape_of_variables(fgraph, input_shapes):
     sym_to_num_dict = dict(zip(output_dims, numeric_output_dims))
 
     l = {}
-    for var in fgraph.shape_feature.shape_of:
-        l[var] = tuple(
-            sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var]
-        )
+    for var in shape_feature.shape_of:
+        l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var])
     return l
 
 
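Behavior is unchanged by the refactor; only the attribute accesses moved through the alias. A usage sketch based on the function's own docstring example (exact output container types, e.g. 0-d arrays vs ints, may differ):

    import pytensor.tensor as pt
    from pytensor.graph import FunctionGraph
    from pytensor.tensor.utils import shape_of_variables

    x = pt.matrix("x")
    y = x[512:]
    fgraph = FunctionGraph([x], [y], clone=False)
    shapes = shape_of_variables(fgraph, {x: (1024, 1024)})
    print(shapes[y])  # expected: (512, 1024)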

@@ -177,7 +176,7 @@ def broadcast_static_dim_lengths(
 
 
 def _parse_gufunc_signature(
-    signature,
+    signature: str,
 ) -> tuple[
     list[tuple[str, ...]], ...
 ]:  # mypy doesn't know it's always a length two tuple
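This is the same parser the numba dispatch above feeds `op.signature` into, which is why annotating it as `str` forced the `assert isinstance` there. The return value is always a two-tuple of (input specs, output specs), even though the annotation cannot express that. Expected behavior, mirroring numpy's private parser of the same name:

    from pytensor.tensor.utils import _parse_gufunc_signature

    inputs, outputs = _parse_gufunc_signature("(m,n),(n,p)->(m,p)")
    print(inputs)   # expected: [('m', 'n'), ('n', 'p')]
    print(outputs)  # expected: [('m', 'p')]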
