Skip to content

Commit ed5e480

Browse files
committed
Make all zips strict in tests
1 parent 26ba673 commit ed5e480

34 files changed

+186
-110
lines changed

tests/compile/function/test_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def test_copy_share_memory(self):
371371

372372
# Assert storages of SharedVariable without updates are shared
373373
for (input, _1, _2), here, there in zip(
374-
ori.indices, ori.input_storage, cpy.input_storage
374+
ori.indices, ori.input_storage, cpy.input_storage, strict=True
375375
):
376376
assert here.data is there.data
377377

@@ -467,7 +467,7 @@ def test_swap_SharedVariable_with_given(self):
467467
swap={train_x: test_x, train_y: test_y}, delete_updates=True
468468
)
469469

470-
for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs):
470+
for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs, strict=True):
471471
assert in1.value is in2.value
472472

473473
def test_copy_delete_updates(self):
@@ -899,7 +899,7 @@ def test_deepcopy(self):
899899
# print 'f.defaults = %s' % (f.defaults, )
900900
# print 'g.defaults = %s' % (g.defaults, )
901901
for (f_req, f_feed, f_val), (g_req, g_feed, g_val) in zip(
902-
f.defaults, g.defaults
902+
f.defaults, g.defaults, strict=True
903903
):
904904
assert f_req == g_req and f_feed == g_feed and f_val == g_val
905905

@@ -1070,7 +1070,7 @@ def test_optimizations_preserved(self):
10701070
tf = f.maker.fgraph.toposort()
10711071
tg = f.maker.fgraph.toposort()
10721072
assert len(tf) == len(tg)
1073-
for nf, ng in zip(tf, tg):
1073+
for nf, ng in zip(tf, tg, strict=True):
10741074
assert nf.op == ng.op
10751075
assert len(nf.inputs) == len(ng.inputs)
10761076
assert len(nf.outputs) == len(ng.outputs)

tests/compile/test_builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,5 +722,5 @@ def test_debugprint():
722722
└─ *2-<Matrix(float64, shape=(?, ?))> [id I]
723723
"""
724724

725-
for truth, out in zip(exp_res.split("\n"), lines):
725+
for truth, out in zip(exp_res.split("\n"), lines, strict=True):
726726
assert truth.strip() == out.strip()

tests/d3viz/test_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def setup_method(self):
1919
def node_counts(self, graph):
2020
node_types = [node.get_attributes()["node_type"] for node in graph.get_nodes()]
2121
a, b = np.unique(node_types, return_counts=True)
22-
nc = dict(zip(a, b))
22+
nc = dict(zip(a, b, strict=True))
2323
return nc
2424

2525
@pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"])

tests/graph/test_fg.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,22 @@ def test_pickle(self):
3131
s = pickle.dumps(func)
3232
new_func = pickle.loads(s)
3333

34-
assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
35-
assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
34+
assert all(
35+
type(a) is type(b)
36+
for a, b in zip(func.inputs, new_func.inputs, strict=True)
37+
)
38+
assert all(
39+
type(a) is type(b)
40+
for a, b in zip(func.outputs, new_func.outputs, strict=True)
41+
)
3642
assert all(
3743
type(a.op) is type(b.op) # noqa: E721
38-
for a, b in zip(func.apply_nodes, new_func.apply_nodes)
44+
for a, b in zip(func.apply_nodes, new_func.apply_nodes, strict=True)
45+
)
46+
assert all(
47+
a.type == b.type
48+
for a, b in zip(func.variables, new_func.variables, strict=True)
3949
)
40-
assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
4150

4251
def test_validate_inputs(self):
4352
var1 = op1()

tests/graph/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ def __init__(self, inner_inputs, inner_outputs):
137137
if not isinstance(v, Constant)
138138
]
139139
outputs = clone_replace(inner_outputs, replace=input_replacements)
140-
_, inputs = zip(*input_replacements) if input_replacements else (None, [])
140+
_, inputs = (
141+
zip(*input_replacements, strict=True) if input_replacements else (None, [])
142+
)
141143
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
142144

143145
def make_node(self, *inputs):

tests/link/jax/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def compare_jax_and_py(
7979
py_res = pytensor_py_fn(*test_inputs)
8080

8181
if len(fgraph.outputs) > 1:
82-
for j, p in zip(jax_res, py_res):
82+
for j, p in zip(jax_res, py_res, strict=True):
8383
assert_fn(j, p)
8484
else:
8585
assert_fn(jax_res, py_res)

tests/link/jax/test_random.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def test_random_updates(rng_ctor):
6161
# Check that original rng variable content was not overwritten when calling jax_typify
6262
assert all(
6363
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
64-
for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__())
64+
for a, b in zip(
65+
rng.get_value().__getstate__(), original_value.__getstate__(), strict=True
66+
)
6567
)
6668

6769

@@ -92,7 +94,9 @@ def test_replaced_shared_rng_storage_order(noise_first):
9294
), "Test may need to be tweaked"
9395

9496
# Confirm that input_storage type and fgraph input order are aligned
95-
for storage, fgrapn_input in zip(f.input_storage, f.maker.fgraph.inputs):
97+
for storage, fgrapn_input in zip(
98+
f.input_storage, f.maker.fgraph.inputs, strict=True
99+
):
96100
assert storage.type == fgrapn_input.type
97101

98102
assert mu.get_value() == 1

tests/link/numba/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def assert_fn(x, y):
292292
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
293293

294294
if len(fn_outputs) > 1:
295-
for j, p in zip(numba_res, py_res):
295+
for j, p in zip(numba_res, py_res, strict=True):
296296
assert_fn(j, p)
297297
else:
298298
assert_fn(numba_res[0], py_res[0])

tests/link/numba/test_scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
488488

489489
ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE"))
490490
ref_res = ref_fn(*test.values())
491-
for numba_r, ref_r in zip(numba_res, ref_res):
491+
for numba_r, ref_r in zip(numba_res, ref_res, strict=True):
492492
np.testing.assert_array_almost_equal(numba_r, ref_r)
493493

494494
benchmark(numba_fn, *test.values())

tests/link/pytorch/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def compare_pytorch_and_py(
6565
py_res = pytensor_py_fn(*test_inputs)
6666

6767
if len(fgraph.outputs) > 1:
68-
for j, p in zip(pytorch_res, py_res):
68+
for j, p in zip(pytorch_res, py_res, strict=True):
6969
assert_fn(j.cpu(), p)
7070
else:
7171
assert_fn([pytorch_res[0].cpu()], py_res)

0 commit comments

Comments
 (0)