Skip to content

Commit c8839c8

Browse files
ArmavicaricardoV94
authored andcommitted
Fix RUF015
1 parent d3dd34e commit c8839c8

File tree

9 files changed

+24
-24
lines changed

9 files changed

+24
-24
lines changed

pytensor/link/c/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ def unpickle_failure():
10381038
_logger.info(f"deleting ModuleCache entry {entry}")
10391039
key_data.delete_keys_from(self.entry_from_key)
10401040
del self.module_hash_to_key_data[module_hash]
1041-
if key_data.keys and list(key_data.keys)[0][0]:
1041+
if key_data.keys and next(iter(key_data.keys))[0]:
10421042
# this is a versioned entry, so should have been on
10431043
# disk. Something weird happened to cause this, so we
10441044
# are responding by printing a warning, removing

pytensor/tensor/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def broadcast_static_dim_lengths(
151151
dim_lengths_set = set(dim_lengths)
152152
# All dim_lengths are the same
153153
if len(dim_lengths_set) == 1:
154-
return tuple(dim_lengths_set)[0]
154+
return next(iter(dim_lengths_set))
155155

156156
# Only valid indeterminate case
157157
if dim_lengths_set == {None, 1}:
@@ -161,7 +161,7 @@ def broadcast_static_dim_lengths(
161161
dim_lengths_set.discard(None)
162162
if len(dim_lengths_set) > 1:
163163
raise ValueError
164-
return tuple(dim_lengths_set)[0]
164+
return next(iter(dim_lengths_set))
165165

166166

167167
# Copied verbatim from numpy.lib.function_base

tests/link/test_vm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_allow_gc_cvm():
275275
f = function([v], v + 1, mode=mode)
276276

277277
f([1])
278-
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
278+
n = next(iter(f.maker.fgraph.apply_nodes)).outputs[0]
279279
assert f.vm.storage_map[n][0] is None
280280
assert f.vm.allow_gc is True
281281

tests/scan/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,7 +1630,7 @@ def reset_rng_grad_fn(*args):
16301630

16311631
# Also validate that the mappings outer_inp_from_outer_out and
16321632
# outer_inp_from_inner_inp produce the correct results
1633-
scan_node = list(updates.values())[0].owner
1633+
scan_node = next(iter(updates.values())).owner
16341634

16351635
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
16361636
result = var_mappings["outer_inp_from_outer_out"]
@@ -1922,7 +1922,7 @@ def inner_fn():
19221922
_, updates = scan(
19231923
inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False
19241924
)
1925-
cost = list(updates.values())[0]
1925+
cost = next(iter(updates.values()))
19261926
g_sh = grad(cost, shared_var)
19271927
fgrad = function([], g_sh)
19281928
assert fgrad() == 1

tests/scan/test_rewriting.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def lambda_fn(h, W1, W2):
270270

271271
f = function([h0, W1, W2], o, mode=self.mode)
272272

273-
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
273+
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
274274
assert (
275275
len(
276276
[
@@ -444,9 +444,9 @@ def test_dot_not_output(self):
444444
# Ensure that the optimization was performed correctly in f_opt
445445
# The inner function of scan should have only one output and it should
446446
# not be the result of a Dot
447-
scan_node = [
447+
scan_node = next(
448448
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
449-
][0]
449+
)
450450
assert len(scan_node.op.inner_outputs) == 1
451451
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
452452

@@ -488,9 +488,9 @@ def inner_fct(vect, mat):
488488
# Ensure that the optimization was performed correctly in f_opt
489489
# The inner function of scan should have only one output and it should
490490
# not be the result of a Dot
491-
scan_node = [
491+
scan_node = next(
492492
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
493-
][0]
493+
)
494494
# NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
495495
# HAVE ONLY 1 OUTPUT.
496496
assert len(scan_node.op.inner_outputs) == 2
@@ -536,9 +536,9 @@ def inner_fct(seq1, previous_output1, nonseq1):
536536
# Ensure that the optimization was performed correctly in f_opt
537537
# The inner function of scan should have only one output and it should
538538
# not be the result of a Dot
539-
scan_node = [
539+
scan_node = next(
540540
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
541-
][0]
541+
)
542542
assert len(scan_node.op.inner_outputs) == 2
543543
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
544544

@@ -1639,7 +1639,7 @@ def lambda_fn(h, W1, W2):
16391639
)
16401640

16411641
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
1642-
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
1642+
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
16431643
assert (
16441644
len(
16451645
[
@@ -1673,7 +1673,7 @@ def lambda_fn(W1, h, W2):
16731673
)
16741674

16751675
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
1676-
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
1676+
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
16771677

16781678
assert (
16791679
len(
@@ -1709,7 +1709,7 @@ def lambda_fn(W1, h, W2):
17091709

17101710
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
17111711
f = function([_h0, _W1, _W2], o, mode="FAST_RUN")
1712-
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
1712+
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
17131713

17141714
assert len(scan_node.op.inner_inputs) == 1
17151715

tests/tensor/rewriting/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3274,7 +3274,7 @@ def test_local_reduce_broadcast_some_0(self):
32743274
order = f.maker.fgraph.toposort()
32753275
assert 1 == sum(isinstance(node.op, CAReduce) for node in order)
32763276

3277-
node = [node for node in order if isinstance(node.op, CAReduce)][0]
3277+
node = next(node for node in order if isinstance(node.op, CAReduce))
32783278

32793279
op = node.op
32803280
assert isinstance(op, CAReduce)

tests/tensor/test_merge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_merge_with_weird_eq():
7575
MergeOptimizer().rewrite(g)
7676

7777
assert len(g.apply_nodes) == 1
78-
node = list(g.apply_nodes)[0]
78+
node = next(iter(g.apply_nodes))
7979
assert len(node.inputs) == 2
8080
assert node.inputs[0] is node.inputs[1]
8181

@@ -87,6 +87,6 @@ def test_merge_with_weird_eq():
8787
MergeOptimizer().rewrite(g)
8888

8989
assert len(g.apply_nodes) == 1
90-
node = list(g.apply_nodes)[0]
90+
node = next(iter(g.apply_nodes))
9191
assert len(node.inputs) == 2
9292
assert node.inputs[0] is node.inputs[1]

tests/tensor/test_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_bad_shape(self):
465465
f(xval)
466466

467467
assert isinstance(
468-
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
468+
next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
469469
.inputs[0]
470470
.type,
471471
self.input_type,
@@ -475,7 +475,7 @@ def test_bad_shape(self):
475475
xval = np.random.random((2, 3)).astype(config.floatX)
476476
f = pytensor.function([x], specify_shape(x, 2, 3), mode=self.mode)
477477
assert isinstance(
478-
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
478+
next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
479479
.inputs[0]
480480
.type,
481481
self.input_type,

tests/test_ifelse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def test_multiple_out(self):
194194
f = function([c, x1, x2, y1, y2], z, mode=self.mode)
195195
self.assertFunctionContains1(f, self.get_ifelse(2))
196196

197-
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0]
197+
ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
198198
assert len(ifnode.outputs) == 2
199199

200200
rng = np.random.default_rng(utt.fetch_seed())
@@ -369,7 +369,7 @@ def test_remove_useless_inputs1(self):
369369
z = ifelse(c, (x, x), (y, y))
370370
f = function([c, x, y], z)
371371

372-
ifnode = [n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse)][0]
372+
ifnode = next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse))
373373
assert len(ifnode.inputs) == 3
374374

375375
@pytest.mark.skip(reason="Optimization temporarily disabled")
@@ -382,7 +382,7 @@ def test_remove_useless_inputs2(self):
382382
z = ifelse(c, (x1, x1, x1, x2, x2), (y1, y1, y2, y2, y2))
383383
f = function([c, x1, x2, y1, y2], z)
384384

385-
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0]
385+
ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
386386
assert len(ifnode.outputs) == 3
387387

388388
@pytest.mark.skip(reason="Optimization temporarily disabled")

0 commit comments

Comments
 (0)