Skip to content

Commit 96060bf

Browse files
Fix tuple-related type hints
1 parent c299bc7 commit 96060bf

File tree

13 files changed

+23
-24
lines changed

13 files changed

+23
-24
lines changed

pytensor/graph/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
ComputeMapType = Dict[Variable, List[bool]]
4242
InputStorageType = List[StorageCellType]
4343
OutputStorageType = List[StorageCellType]
44-
ParamsInputType = Optional[Tuple[Any]]
44+
ParamsInputType = Optional[Tuple[Any, ...]]
4545
PerformMethodType = Callable[
4646
[Apply, List[Any], OutputStorageType, ParamsInputType], None
4747
]

pytensor/link/c/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from abc import abstractmethod
3-
from typing import Callable, Dict, List, Tuple, Union
3+
from typing import Callable, Dict, List, Tuple
44

55
from pytensor.graph.basic import Apply, Constant
66
from pytensor.graph.utils import MethodNotDefined
@@ -149,7 +149,7 @@ def c_init_code(self, **kwargs) -> List[str]:
149149
"""Return a list of code snippets to be inserted in module initialization."""
150150
return []
151151

152-
def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]:
152+
def c_code_cache_version(self) -> Tuple[int, ...]:
153153
"""Return a tuple of integers indicating the version of this `Op`.
154154
155155
An empty tuple indicates an "unversioned" `Op` that will not be cached
@@ -566,7 +566,7 @@ def c_cleanup(self, name: str, sub: Dict[str, str]) -> str:
566566
"""
567567
return ""
568568

569-
def c_code_cache_version(self) -> Union[Tuple, Tuple[int]]:
569+
def c_code_cache_version(self) -> Tuple[int, ...]:
570570
"""Return a tuple of integers indicating the version of this type.
571571
572572
An empty tuple indicates an "unversioned" type that will not

pytensor/link/c/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def lquote_macro(txt: str) -> str:
240240
return "\n".join(res)
241241

242242

243-
def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]:
243+
def get_sub_macros(sub: Dict[str, str]) -> Tuple[str, str]:
244244
define_macros = []
245245
undef_macros = []
246246
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
@@ -533,7 +533,7 @@ def format_c_function_args(self, inp: List[str], out: List[str]) -> str:
533533

534534
def get_c_macros(
535535
self, node: Apply, name: str, check_input: Optional[bool] = None
536-
) -> Union[Tuple[str], Tuple[str, str]]:
536+
) -> Tuple[str, str]:
537537
"Construct a pair of C ``#define`` and ``#undef`` code strings."
538538
define_template = "#define %s %s"
539539
undef_template = "#undef %s"

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def add_inner_in_expr(
123123
# These outer-inputs are indexed without offsets or storage wrap-around
124124
add_inner_in_expr(outer_in_name, 0, None)
125125

126-
inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict(
126+
inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
127127
zip(
128128
outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names,
129129
op.info.mit_mot_in_slices
@@ -157,7 +157,7 @@ def add_inner_in_expr(
157157
# storage array like a circular buffer, and that's why we need to track the
158158
# storage size along with the taps length/indexing offset.
159159
def add_output_storage_post_proc_stmt(
160-
outer_in_name: str, tap_sizes: Tuple[int], storage_size: str
160+
outer_in_name: str, tap_sizes: Tuple[int, ...], storage_size: str
161161
):
162162

163163
tap_size = max(tap_sizes)

pytensor/raise_op.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Symbolic Op for raising an exception."""
22

33
from textwrap import indent
4-
from typing import Tuple
54

65
import numpy as np
76

@@ -63,7 +62,7 @@ def __eq__(self, other):
6362
def __hash__(self):
6463
return hash((self.msg, self.exc_type))
6564

66-
def make_node(self, value: Variable, *conds: Tuple[Variable]):
65+
def make_node(self, value: Variable, *conds: Variable):
6766
"""
6867
6968
Parameters

pytensor/sandbox/multinomial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from typing import Tuple, Union
2+
from typing import Tuple
33

44
import numpy as np
55

@@ -18,7 +18,7 @@ class MultinomialFromUniform(COp):
1818
TODO : need description for parameter 'odtype'
1919
"""
2020

21-
__props__: Union[Tuple[str], Tuple[str, str]] = ("odtype",)
21+
__props__: Tuple[str, ...] = ("odtype",)
2222

2323
def __init__(self, odtype):
2424
self.odtype = odtype

pytensor/scalar/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3998,7 +3998,7 @@ class Composite(ScalarOp, HasInnerGraph):
39983998
39993999
"""
40004000

4001-
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
4001+
init_param: Tuple[str, ...] = ("inputs", "outputs")
40024002

40034003
def __init__(self, inputs, outputs):
40044004
# We need to clone the graph as sometimes its nodes already

pytensor/tensor/blas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
except ImportError:
138138
pass
139139

140-
from typing import Tuple, Union
140+
from typing import Tuple
141141

142142
import pytensor.scalar
143143
from pytensor.compile.mode import optdb
@@ -522,7 +522,7 @@ class GemmRelated(COp):
522522
523523
"""
524524

525-
__props__: Union[Tuple, Tuple[str]] = ()
525+
__props__: Tuple[str, ...] = ()
526526

527527
def c_support_code(self, **kwargs):
528528
# return cblas_header_text()

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1763,7 +1763,7 @@ def linspace(start, end, steps):
17631763

17641764

17651765
def broadcast_to(
1766-
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable]]
1766+
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
17671767
) -> TensorVariable:
17681768
"""Broadcast an array to a new shape.
17691769

pytensor/tensor/nlinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Tuple, Union
2+
from typing import Tuple
33

44
import numpy as np
55

@@ -238,7 +238,7 @@ class Eig(Op):
238238
"""
239239

240240
_numop = staticmethod(np.linalg.eig)
241-
__props__: Union[Tuple, Tuple[str]] = ()
241+
__props__: Tuple[str, ...] = ()
242242

243243
def make_node(self, x):
244244
x = as_tensor_variable(x)

0 commit comments

Comments
 (0)