Commit f2edaf9

Apply rule RUF005
1 parent 0a83612 commit f2edaf9

File tree: 9 files changed (+34, -32 lines)
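
For context: Ruff's rule RUF005 (collection-literal-concatenation) flags code that builds a tuple or list by concatenating a literal with another sequence, and suggests unpacking the sequence inside the literal instead. A minimal sketch of the before/after pattern, using illustrative names that do not appear in this commit:

    # Hypothetical example of the RUF005 rewrite; `rest` is not a name from the diff.
    rest = (6, 7, 8)

    shape_old = (-1,) + rest   # flagged by RUF005: literal + concatenation
    shape_new = (-1, *rest)    # preferred: unpack the sequence inside the literal

    assert shape_old == shape_new == (-1, 6, 7, 8)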

pytensor/link/jax/dispatch/random.py

Lines changed: 3 additions & 3 deletions
@@ -324,15 +324,15 @@ def sample_fn(rng_key, size, dtype, *parameters):
     batch_sampling_keys = jax.random.split(rng_key, np.prod(size))

     # Ravel the batch dimensions because vmap only works along a single axis
-    raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
+    raveled_batch_a = a.reshape((-1, *a.shape[batch_ndim:]))
     if p is None:
         raveled_sample = jax.vmap(
             lambda key, a: jax.random.choice(
                 key, a, shape=core_shape, replace=False, p=None
             )
         )(batch_sampling_keys, raveled_batch_a)
     else:
-        raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
+        raveled_batch_p = p.reshape((-1, *p.shape[batch_ndim:]))
         raveled_sample = jax.vmap(
             lambda key, a, p: jax.random.choice(
                 key, a, shape=core_shape, replace=False, p=p
@@ -363,7 +363,7 @@ def sample_fn(rng_key, size, dtype, *parameters):
     x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])

     batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
-    raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
+    raveled_batch_x = x.reshape((-1, *x.shape[batch_ndim:]))
     raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
         batch_sampling_keys, raveled_batch_x
     )

pytensor/scan/op.py

Lines changed: 3 additions & 3 deletions
@@ -2104,7 +2104,7 @@ def perform(self, node, inputs, output_storage):
     # are read and written.
     # This way, there will be no information overwritten
     # before it is read (as it used to happen).
-    shape = (pdx,) + output_storage[idx][0].shape[1:]
+    shape = (pdx, *output_storage[idx][0].shape[1:])
     tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
     tmp[:] = output_storage[idx][0][:pdx]
     output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
@@ -2113,7 +2113,7 @@ def perform(self, node, inputs, output_storage):
         output_storage[idx][0][store_steps[idx] - pdx :] = tmp
         del tmp
     else:
-        shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:]
+        shape = (store_steps[idx] - pdx, *output_storage[idx][0].shape[1:])
         tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
         tmp[:] = output_storage[idx][0][pdx:]
         output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
@@ -2304,7 +2304,7 @@ def infer_shape(self, fgraph, node, input_shapes):
             if x is None:
                 scan_outs.append(None)
             else:
-                scan_outs.append((Shape_i(0)(o),) + x[1:])
+                scan_outs.append((Shape_i(0)(o), *x[1:]))
         return scan_outs

     def connection_pattern(self, node):

pytensor/tensor/basic.py

Lines changed: 3 additions & 3 deletions
@@ -4051,8 +4051,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     # Re-order axes so they correspond to diagonals at axis1, axis2
     axes = list(range(diag.type.ndim - 1))
     last_idx = axes[-1]
-    axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
-    axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
+    axes = [*axes[:axis1], last_idx + 1, *axes[axis1:]]
+    axes = [*axes[:axis2], last_idx + 2, *axes[axis2:]]
     result = result.transpose(axes)

     return AllocDiag(
@@ -4525,7 +4525,7 @@ def _make_along_axis_idx(arr_shape, indices, axis):
         if dim is None:
             fancy_index.append(indices)
         else:
-            ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
+            ind_shape = (*shape_ones[:dim], -1, *shape_ones[dim + 1 :])
             fancy_index.append(arange(n).reshape(ind_shape))

     return tuple(fancy_index)
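
RUF005 applies to list literals as well as tuples, which is why the alloc_diag lines above become [*axes[:axis1], last_idx + 1, *axes[axis1:]]. A small illustrative sketch of the list form (the names here are hypothetical, not from basic.py):

    # Hypothetical example of the list variant of the RUF005 rewrite.
    axes = [0, 1, 2]

    before = axes[:1] + [99] + axes[1:]   # flagged by RUF005
    after = [*axes[:1], 99, *axes[1:]]    # unpacking form produced by the fix

    assert before == after == [0, 99, 1, 2]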

pytensor/tensor/conv/abstract_conv.py

Lines changed: 11 additions & 9 deletions
@@ -244,7 +244,7 @@ def get_conv_gradweights_shape(
         for i in range(len(subsample))
     )
     if unshared:
-        return (nchan,) + top_shape[2:] + (nkern,) + out_shp
+        return (nchan, *top_shape[2:], nkern, *out_shp)
     else:
         return (nchan, nkern, *out_shp)

@@ -2906,9 +2906,9 @@ def perform(self, node, inp, out_):
     def correct_for_groups(mat):
         mshp0 = mat.shape[0] // self.num_groups
         mshp1 = mat.shape[1] * self.num_groups
-        mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
+        mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
         mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
-        mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :])
+        mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
         return mat

     if self.num_groups > 1:
@@ -3283,7 +3283,7 @@ def perform(self, node, inp, out_):
     def correct_for_groups(mat):
         mshp0 = mat.shape[0] // self.num_groups
         mshp1 = mat.shape[-self.convdim - 1] * self.num_groups
-        mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
+        mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
         if self.unshared:
             # for 2D -> (1, 2, 3, 0, 4, 5, 6)
             mat = mat.transpose(
@@ -3294,14 +3294,16 @@ def correct_for_groups(mat):
                 )
             )
             mat = mat.reshape(
-                (mshp0,)
-                + mat.shape[1 : 1 + self.convdim]
-                + (mshp1,)
-                + mat.shape[-self.convdim :]
+                (
+                    mshp0,
+                    *mat.shape[1 : 1 + self.convdim],
+                    mshp1,
+                    *mat.shape[-self.convdim :],
+                )
             )
         else:
             mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
-            mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :])
+            mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
         return mat

     kern = correct_for_groups(kern)

tests/compile/test_builders.py

Lines changed: 1 addition & 1 deletion
@@ -563,7 +563,7 @@ def test_make_node_shared(self):
     assert y_clone != y
     y_clone.name = "y_clone"

-    out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
+    out_new = test_ofg.make_node(*([*out.owner.inputs[:1], y_clone])).outputs[0]

     assert "on_unused_input" in out_new.owner.op.kwargs
     assert out_new.owner.op.shared_inputs == [y_clone]

tests/scan/test_basic.py

Lines changed: 2 additions & 2 deletions
@@ -144,12 +144,12 @@ def prod(inputs):
                 t = t.flatten()
                 t[pos] += _eps
                 t = t.reshape(pt[i].shape)
-                f_eps = f(*(pt[:i] + [t] + pt[i + 1 :]))
+                f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
                 _g.append(np.asarray((f_eps - f_x) / _eps))
             gx.append(np.asarray(_g).reshape(pt[i].shape))
         else:
             t = np.array(pt[i] + _eps)
-            f_eps = f(*(pt[:i] + [t] + pt[i + 1 :]))
+            f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
             gx.append(np.asarray((f_eps - f_x) / _eps))
     self.gx = gx

tests/tensor/conv/test_abstract_conv.py

Lines changed: 3 additions & 3 deletions
@@ -266,7 +266,7 @@ def test_get_shape(self):
     computed_image_shape = get_conv_gradinputs_shape(
         kernel_shape, output_shape, b, (2, 3), (d, d)
     )
-    image_shape_with_None = image_shape[:2] + (None, None)
+    image_shape_with_None = (*image_shape[:2], None, None)
     assert computed_image_shape == image_shape_with_None

     # compute the kernel_shape given this output_shape
@@ -276,7 +276,7 @@ def test_get_shape(self):

     # if border_mode == 'half', the shape should be None
     if b == "half":
-        kernel_shape_with_None = kernel_shape[:2] + (None, None)
+        kernel_shape_with_None = (*kernel_shape[:2], None, None)
         assert computed_kernel_shape == kernel_shape_with_None
     else:
         assert computed_kernel_shape == kernel_shape
@@ -285,7 +285,7 @@ def test_get_shape(self):
     computed_kernel_shape = get_conv_gradweights_shape(
         kernel_shape, output_shape, b, (2, 3), (d, d)
     )
-    kernel_shape_with_None = kernel_shape[:2] + (None, None)
+    kernel_shape_with_None = (*kernel_shape[:2], None, None)
     assert computed_kernel_shape == kernel_shape_with_None

tests/tensor/test_extra_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1019,7 +1019,7 @@ def check(shape, index_ndim, mode, order):
     )
     # create some invalid indices to test the mode
     if mode in ("wrap", "clip"):
-        multi_index = (multi_index[0] - 1,) + multi_index[1:]
+        multi_index = (multi_index[0] - 1, *multi_index[1:])
     # test with scalars and higher-dimensional indices
     if index_ndim == 0:
         multi_index = tuple(i[-1] for i in multi_index)

tests/tensor/test_subtensor.py

Lines changed: 7 additions & 7 deletions
@@ -1272,7 +1272,7 @@ def test_advanced1_inc_and_set(self):
     if len(inc_shape) == len(data_shape) and (
         len(inc_shapes) == 0 or inc_shape[0] != 1
     ):
-        inc_shape = (n_to_inc,) + inc_shape[1:]
+        inc_shape = (n_to_inc, *inc_shape[1:])

     # Symbolic variable with increment value.
     inc_var_static_shape = tuple(
@@ -2822,15 +2822,15 @@ def bcast_shape_tuple(x):
     (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        test_idx[:2] + (slice(None, None),),
+        (*test_idx[:2], slice(None, None)),
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        (slice(None, None),) + test_idx[:1],
+        (slice(None, None), *test_idx[:1]),
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        (slice(None, None), None) + test_idx[1:2],
+        (slice(None, None), None, *test_idx[1:2]),
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
@@ -2842,15 +2842,15 @@ def bcast_shape_tuple(x):
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        test_idx[:1] + (slice(None, None),) + test_idx[1:2],
+        (*test_idx[:1], slice(None, None), *test_idx[1:2]),
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        test_idx[:1] + (slice(None, None),) + test_idx[1:2] + (slice(None, None),),
+        (*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)),
     ),
     (
         np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
-        test_idx[:1] + (None,) + test_idx[1:2],
+        (*test_idx[:1], None, *test_idx[1:2]),
     ),
     (np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
     (np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
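
A change like this is usually generated by Ruff's autofix rather than edited by hand. Assuming a standard Ruff installation, an invocation along the following lines would apply the rule across the repository (the exact command used for this commit is not recorded here):

    ruff check --select RUF005 --fix .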
