
Commit 8aeda39

Armavica authored and ricardoV94 committed
Fix RUF005
Automated fixes by Ruff (rule RUF005), plus an update of the TensorConstructorType type in scan/op.py because mypy didn't like something there.
1 parent: b142cb5 · commit: 8aeda39
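
For context, RUF005 is Ruff's "collection-literal-concatenation" rule: it flags code that builds a list or tuple by concatenating a literal with another sequence, and suggests unpacking the sequence into the literal instead. A minimal sketch of the rewrite pattern (the variable names below are illustrative, not taken from the diff):

    # Flagged by RUF005: concatenating a literal with another sequence.
    head = ["a"]
    rest = ("b", "c")
    items = head + list(rest)  # list concatenation
    pair = ("x",) + rest       # tuple concatenation

    # Preferred form: unpack into the literal; the results are identical.
    items = [*head, *rest]  # ["a", "b", "c"]
    pair = ("x", *rest)     # ("x", "b", "c")

The unpacked form accepts any iterable, so wrappers such as list(...) that survive in the automated fixes below (e.g. [condition, *list(monitored_vars)]) are carried over from the original expressions rather than required by the new syntax.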


69 files changed (+392 additions, -330 deletions)

pytensor/breakpoint.py

Lines changed: 2 additions & 2 deletions
@@ -92,7 +92,7 @@ def make_node(self, condition, *monitored_vars):
             new_op.inp_types.append(monitored_vars[i].type)

         # Build the Apply node
-        inputs = [condition] + list(monitored_vars)
+        inputs = [condition, *list(monitored_vars)]
         outputs = [inp.type() for inp in monitored_vars]
         return Apply(op=new_op, inputs=inputs, outputs=outputs)

@@ -142,7 +142,7 @@ def perform(self, node, inputs, output_storage):
             output_storage[i][0] = inputs[i + 1]

     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()()] + output_gradients
+        return [DisconnectedType()(), *output_gradients]

     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)

pytensor/compile/debugmode.py

Lines changed: 2 additions & 2 deletions
@@ -892,9 +892,9 @@ def _get_preallocated_maps(

         # Use the same step on all dimensions before the last check_ndim.
         if all(s == 1 for s in out_shape[:-check_ndim]):
-            step_signs_list = [(1,)] + step_signs_list
+            step_signs_list = [(1,), *step_signs_list]
         else:
-            step_signs_list = [(-1, 1)] + step_signs_list
+            step_signs_list = [(-1, 1), *step_signs_list]

         for step_signs in itertools_product(*step_signs_list):
             for step_size in (1, 2):

pytensor/compile/sharedvalue.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
         add_tag_trace(var)
         return var
     except MemoryError as e:
-        e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
+        e.args = (*e.args, "Consider using `pytensor.shared(..., borrow=True)`")
         raise


pytensor/configdefaults.py

Lines changed: 6 additions & 3 deletions
@@ -1382,7 +1382,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_ALGO_BWD_DATA = (
     "none",

@@ -1391,7 +1392,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_ALGO_BWD_FILTER = (
     "none",

@@ -1400,7 +1402,8 @@ def add_caching_dir_configvars():
     "small",
     "winograd_non_fused",
     "fft_tiling",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)

 SUPPORTED_DNN_CONV_PRECISION = (
     "as_input_f32",

pytensor/gradient.py

Lines changed: 1 addition & 1 deletion
@@ -1972,7 +1972,7 @@ def inner_function(*args):
     jacobs, updates = pytensor.scan(
         inner_function,
         sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression] + wrt,
+        non_sequences=[expression, *wrt],
     )
     assert not updates, "Scan has returned a list of updates; this should not happen."
     return as_list_or_tuple(using_list, using_tuple, jacobs)

pytensor/graph/features.py

Lines changed: 7 additions & 5 deletions
@@ -508,11 +508,13 @@ def consistent_(self, fgraph):


 class ReplaceValidate(History, Validator):
-    pickle_rm_attr = (
-        ["replace_validate", "replace_all_validate", "replace_all_validate_remove"]
-        + History.pickle_rm_attr
-        + Validator.pickle_rm_attr
-    )
+    pickle_rm_attr = [
+        "replace_validate",
+        "replace_all_validate",
+        "replace_all_validate_remove",
+        *History.pickle_rm_attr,
+        *Validator.pickle_rm_attr,
+    ]

     def on_attach(self, fgraph):
         for attr in (

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 2 deletions
@@ -405,7 +405,7 @@ def print_profile(cls, stream, prof, level=0):
             else:
                 name = rewrite.name
             idx = rewrites.index(rewrite)
-            ll.append((name, rewrite.__class__.__name__, idx) + nb_n)
+            ll.append((name, rewrite.__class__.__name__, idx, *nb_n))
         lll = sorted(zip(prof, ll), key=lambda a: a[0])

         for t, rewrite in lll[::-1]:

@@ -1138,7 +1138,8 @@ def decorator(f):
         req = requirements
         if inplace:
             dh_handler = dh.DestroyHandler
-            req = tuple(requirements) + (
+            req = (
+                *tuple(requirements),
                 lambda fgraph: fgraph.attach_feature(dh_handler()),
             )
         rval = FromFunctionNodeRewriter(f, tracks, req)

pytensor/graph/rewriting/db.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def register(

         if use_db_name_as_tag:
             if self.name is not None:
-                tags = tags + (self.name,)
+                tags = (*tags, self.name)

         rewriter.name = name
         # This restriction is there because in many place we suppose that

pytensor/graph/utils.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def __init__(self, *args, **kwargs):
         assert list(kwargs.keys()) == ["variable"]
         error_msg = get_variable_trace_string(kwargs["variable"])
         if error_msg:
-            args = args + (error_msg,)
+            args = (*args, error_msg)
         s = "\n".join(args)  # Needed to have the new line print correctly
         super().__init__(s)


pytensor/ifelse.py

Lines changed: 10 additions & 10 deletions
@@ -227,7 +227,7 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):

         return Apply(
             self,
-            [condition] + new_inputs_true_branch + new_inputs_false_branch,
+            [condition, *new_inputs_true_branch, *new_inputs_false_branch],
             output_vars,
         )

@@ -275,11 +275,11 @@ def grad(self, ins, grads):
         # condition + epsilon always triggers the same branch as condition
         condition_grad = condition.zeros_like().astype(config.floatX)

-        return (
-            [condition_grad]
-            + if_true_op(*inputs_true_grad, return_list=True)
-            + if_false_op(*inputs_false_grad, return_list=True)
-        )
+        return [
+            condition_grad,
+            *if_true_op(*inputs_true_grad, return_list=True),
+            *if_false_op(*inputs_false_grad, return_list=True),
+        ]

     def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         cond = node.inputs[0]

@@ -397,7 +397,7 @@ def ifelse(

     new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name)

-    ins = [condition] + list(then_branch) + list(else_branch)
+    ins = [condition, *list(then_branch), *list(else_branch)]
     rval = new_ifelse(*ins, return_list=True)

     if rval_type is None:

@@ -611,7 +611,7 @@ def apply(self, fgraph):
             mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
             pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
             pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+            new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
             mn_name = "?"
             if merging_node.op.name:
                 mn_name = merging_node.op.name

@@ -673,7 +673,7 @@ def cond_remove_identical(fgraph, node):

     new_ifelse = IfElse(n_outs=len(nw_ts), as_view=op.as_view, name=op.name)

-    new_ins = [node.inputs[0]] + nw_ts + nw_fs
+    new_ins = [node.inputs[0], *nw_ts, *nw_fs]
     new_outs = new_ifelse(*new_ins, return_list=True)

     rval = []

@@ -711,7 +711,7 @@ def cond_merge_random_op(fgraph, main_node):
             mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
             pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
             pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+            new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
             mn_name = "?"
             if merging_node.op.name:
                 mn_name = merging_node.op.name
