Skip to content

Commit f9f930c

Browse files
ArmavicaricardoV94
authored andcommitted
Simplify some type checking
1 parent d62f4b1 commit f9f930c

File tree

6 files changed

+11
-25
lines changed

6 files changed

+11
-25
lines changed

pytensor/compile/function/pfunc.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,7 @@ def construct_pfunc_ins_and_outs(
515515
if not isinstance(params, list | tuple):
516516
raise TypeError("The `params` argument must be a list or a tuple")
517517

518-
if not isinstance(no_default_updates, bool) and not isinstance(
519-
no_default_updates, list
520-
):
518+
if not isinstance(no_default_updates, bool | list):
521519
raise TypeError("The `no_default_update` argument must be a boolean or list")
522520

523521
if len(updates) > 0 and not all(

pytensor/compile/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def __init__(
207207
if implicit is None:
208208
from pytensor.compile.sharedvalue import SharedVariable
209209

210-
implicit = isinstance(value, Container) or isinstance(value, SharedVariable)
210+
implicit = isinstance(value, Container | SharedVariable)
211211
super().__init__(
212212
variable=variable,
213213
name=name,

pytensor/gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1788,7 +1788,7 @@ def verify_grad(
17881788
o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
17891789
o_fn_out = o_fn(*[p.copy() for p in pt])
17901790

1791-
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
1791+
if isinstance(o_fn_out, tuple | list):
17921792
raise TypeError(
17931793
"It seems like you are trying to use verify_grad "
17941794
"on an Op or a function which outputs a list: there should"

pytensor/scan/basic.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ def _filter(x):
6868
"""
6969
# Is `x` a container we can iterate on?
7070
iter_on = None
71-
if isinstance(x, list) or isinstance(x, tuple):
71+
if isinstance(x, list | tuple):
7272
iter_on = x
7373
elif isinstance(x, dict):
7474
iter_on = x.items()
7575
if iter_on is not None:
7676
return all(_filter(y) for y in iter_on)
7777
else:
78-
return isinstance(x, Variable) or isinstance(x, until)
78+
return isinstance(x, Variable | until)
7979

8080
if not _filter(ls):
8181
raise ValueError(
@@ -840,11 +840,7 @@ def wrap_into_list(x):
840840
# add only the non-shared variables and non-constants to the arguments of
841841
# the dummy function [ a function should not get shared variables or
842842
# constants as input ]
843-
dummy_args = [
844-
arg
845-
for arg in args
846-
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
847-
]
843+
dummy_args = [arg for arg in args if not isinstance(arg, SharedVariable | Constant)]
848844
# when we apply the lambda expression we get a mixture of update rules
849845
# and outputs that needs to be separated
850846

@@ -1043,16 +1039,14 @@ def wrap_into_list(x):
10431039
other_inner_args = []
10441040

10451041
other_scan_args += [
1046-
arg
1047-
for arg in non_seqs
1048-
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
1042+
arg for arg in non_seqs if not isinstance(arg, SharedVariable | Constant)
10491043
]
10501044

10511045
# Step 5.6 all shared variables with no update rules
10521046
other_inner_args += [
10531047
safe_new(arg, "_copy")
10541048
for arg in non_seqs
1055-
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
1049+
if not isinstance(arg, SharedVariable | Constant)
10561050
]
10571051

10581052
inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))

pytensor/tensor/basic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,9 +1956,7 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
19561956
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
19571957
except NotScalarConstantError:
19581958
pass
1959-
if isinstance(x, ps.ScalarVariable) or isinstance(
1960-
x, ps.sharedvar.ScalarSharedVariable
1961-
):
1959+
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
19621960
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
19631961
x = x.owner.inputs[0]
19641962
else:

tests/tensor/rewriting/test_math.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,9 +2204,7 @@ def test_local_one_minus_erf(self):
22042204
assert len(topo) == 2
22052205
assert topo[0].op == erf
22062206
assert isinstance(topo[1].op, Elemwise)
2207-
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance(
2208-
topo[1].op.scalar_op, ps.Sub
2209-
)
2207+
assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
22102208

22112209
def test_local_erf_minus_one(self):
22122210
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
@@ -2227,9 +2225,7 @@ def test_local_erf_minus_one(self):
22272225
assert len(topo) == 2
22282226
assert topo[0].op == erf
22292227
assert isinstance(topo[1].op, Elemwise)
2230-
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance(
2231-
topo[1].op.scalar_op, ps.Sub
2232-
)
2228+
assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
22332229

22342230

22352231
@pytest.mark.skipif(

0 commit comments

Comments
 (0)