1
- from collections .abc import Sequence
1
+ from collections .abc import Callable , Sequence
2
2
from functools import wraps
3
3
from itertools import zip_longest
4
4
from types import ModuleType
5
- from typing import TYPE_CHECKING , Literal
5
+ from typing import TYPE_CHECKING
6
6
7
7
import numpy as np
8
8
22
22
from pytensor .tensor .random .op import RandomVariable
23
23
24
24
25
- def params_broadcast_shapes (param_shapes , ndims_params , use_pytensor = True ):
25
+ def params_broadcast_shapes (
26
+ param_shapes : Sequence , ndims_params : Sequence [int ], use_pytensor : bool = True
27
+ ) -> list [tuple [int , ...]]:
26
28
"""Broadcast parameters that have different dimensions.
27
29
28
30
Parameters
@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
36
38
37
39
Returns
38
40
=======
39
- bcast_shapes : list of ndarray
41
+ bcast_shapes : list of tuples of ints
40
42
The broadcasted values of `params`.
41
43
"""
42
44
max_fn = maximum if use_pytensor else max
43
45
44
- rev_extra_dims = []
46
+ rev_extra_dims : list [ int ] = []
45
47
for ndim_param , param_shape in zip (ndims_params , param_shapes ):
46
48
# We need this in order to use `len`
47
49
param_shape = tuple (param_shape )
@@ -71,7 +73,9 @@ def max_bcast(x, y):
71
73
return bcast_shapes
72
74
73
75
74
- def broadcast_params (params , ndims_params ):
76
+ def broadcast_params (
77
+ params : Sequence [np .ndarray | TensorVariable ], ndims_params : Sequence [int ]
78
+ ) -> list [np .ndarray ]:
75
79
"""Broadcast parameters that have different dimensions.
76
80
77
81
>>> ndims_params = [1, 2]
@@ -215,7 +219,9 @@ def __init__(
215
219
self ,
216
220
seed : int | None = None ,
217
221
namespace : ModuleType | None = None ,
218
- rng_ctor : Literal [np .random .Generator ] = np .random .default_rng ,
222
+ rng_ctor : Callable [
223
+ [np .random .SeedSequence ], np .random .Generator
224
+ ] = np .random .default_rng ,
219
225
):
220
226
if namespace is None :
221
227
from pytensor .tensor .random import basic # pylint: disable=import-self
0 commit comments