Commit a65c3df

Make zips strict in pytensor/link/numba
1 parent 0767523 commit a65c3df
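
Note (added context, not from the original commit page): since Python 3.10, zip accepts a strict keyword that raises ValueError when its iterables differ in length, instead of silently truncating to the shortest. This commit opts the numba backend into that check wherever the paired sequences are expected to match. A minimal illustration:

    list(zip([1, 2], ["a", "b"], strict=True))     # [(1, 'a'), (2, 'b')]
    list(zip([1, 2, 3], ["a", "b"], strict=True))  # ValueError: zip() argument 2 is shorter than argument 1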

10 files changed: +54 -30 lines

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 2 deletions
@@ -405,7 +405,7 @@ def py_perform_return(inputs):
     def py_perform_return(inputs):
         return tuple(
             out_type.filter(out[0])
-            for out_type, out in zip(output_types, py_perform(inputs))
+            for out_type, out in zip(output_types, py_perform(inputs), strict=True)
         )
 
     @numba_njit
@@ -568,7 +568,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
     func_conditions = [
         f"assert x.shape[{i}] == {shape_input_names}"
         for i, (shape_input, shape_input_names) in enumerate(
-            zip(shape_inputs, shape_input_names)
+            zip(shape_inputs, shape_input_names, strict=True)
         )
         if shape_input is not NoneConst
     ]

pytensor/link/numba/dispatch/cython_support.py

Lines changed: 5 additions & 2 deletions
@@ -45,7 +45,7 @@ def arg_numba_types(self) -> list[DTypeLike]:
     def can_cast_args(self, args: list[DTypeLike]) -> bool:
         ok = True
         count = 0
-        for name, dtype in zip(self.arg_names, self.arg_dtypes):
+        for name, dtype in zip(self.arg_names, self.arg_dtypes, strict=True):
             if name == "__pyx_skip_dispatch":
                 continue
             if len(args) <= count:
@@ -164,7 +164,10 @@ def __wrapper_address__(self):
         return self._func_ptr
 
     def __call__(self, *args, **kwargs):
-        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
+        args = [
+            dtype(arg)
+            for arg, dtype in zip(args, self._signature.arg_dtypes, strict=True)
+        ]
         if self.has_pyx_skip_dispatch():
             output = self._pyfunc(*args[:-1], **kwargs)
         else:

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 6 additions & 4 deletions
@@ -515,8 +515,10 @@ def elemwise(*inputs):
         inputs = [np.asarray(input) for input in inputs]
         inputs_bc = np.broadcast_arrays(*inputs)
         shape = inputs[0].shape
-        for input, bc in zip(inputs, input_bc_patterns):
-            for length, allow_bc, iter_length in zip(input.shape, bc, shape):
+        for input, bc in zip(inputs, input_bc_patterns, strict=True):
+            for length, allow_bc, iter_length in zip(
+                input.shape, bc, shape, strict=True
+            ):
                 if length == 1 and shape and iter_length != 1 and not allow_bc:
                     raise ValueError("Broadcast not allowed.")
 
@@ -529,11 +531,11 @@ def elemwise(*inputs):
             outs = scalar_op_fn(*vals)
             if not isinstance(outs, tuple):
                 outs = (outs,)
-            for out, out_val in zip(outputs, outs):
+            for out, out_val in zip(outputs, outs, strict=True):
                 out[idx] = out_val
 
         outputs_summed = []
-        for output, bc in zip(outputs, output_bc_patterns):
+        for output, bc in zip(outputs, output_bc_patterns, strict=True):
             axes = tuple(np.nonzero(bc)[0])
             outputs_summed.append(output.sum(axes, keepdims=True))
         if len(outputs_summed) != 1:

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
 
     new_arr = arr.T.astype(np.float64).copy()
     for i, b in enumerate(new_arr):
-        for j, (d, v) in enumerate(zip(shape, b)):
+        for j, (d, v) in enumerate(zip(shape, b)):  # noqa: B905
             if v < 0 or v >= d:
                 mode_fn(new_arr, i, j, v, d)
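
Note on the # noqa: B905 markers (added context): B905 is the flake8-bugbear/ruff rule that flags zip calls lacking an explicit strict argument. The suppressed loops here and in slinalg.py and subtensor.py below execute inside numba-compiled functions, where zip is assumed not to accept the strict keyword, so the rule is silenced rather than the keyword added; subtensor.py compensates with an explicit length check before its loops. A hypothetical sketch of that pattern (dot_checked is illustrative, not from the commit):

    import numba
    import numpy as np

    @numba.njit
    def dot_checked(xs, ys):
        # A manual guard stands in for strict=True, which numba's zip is
        # assumed to reject in compiled code.
        if len(xs) != len(ys):
            raise ValueError("length mismatch")
        total = 0.0
        for x, y in zip(xs, ys):  # noqa: B905
            total += x * y
        return total

    print(dot_checked(np.arange(3.0), np.ones(3)))  # 3.0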

pytensor/link/numba/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def {scalar_op_fn_name}({input_names}):
         [
             f"direct_cast({i_name}, {i_tmp_dtype_name})"
             for i_name, i_tmp_dtype_name in zip(
-                input_names, input_tmp_dtype_names.keys()
+                input_names, input_tmp_dtype_names.keys(), strict=False
             )
         ]
     )
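
Note (added context): scalar.py is the one spot where the commit writes strict=False rather than strict=True, making the pre-3.10 truncating behavior explicit, presumably because input_names and input_tmp_dtype_names.keys() can legitimately differ in length. The explicit keyword satisfies B905-style linters without changing behavior:

    list(zip([1, 2, 3], ["a", "b"], strict=False))  # [(1, 'a'), (2, 'b')] -- truncates to the shorter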

pytensor/link/numba/dispatch/scan.py

Lines changed: 8 additions & 2 deletions
@@ -163,10 +163,11 @@ def add_inner_in_expr(
             op.info.mit_mot_in_slices
             + op.info.mit_sot_in_slices
             + op.info.sit_sot_in_slices,
+            strict=True,
         )
     )
     inner_in_names_to_output_taps: dict[str, tuple[int, ...] | None] = dict(
-        zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices)
+        zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices, strict=True)
     )
 
     # Inner-outputs consist of:
@@ -373,7 +374,12 @@ def add_output_storage_post_proc_stmt(
     inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
 
     inner_out_to_outer_out_stmts = "\n".join(
-        [f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
+        [
+            f"{s} = {d}"
+            for s, d in zip(
+                inner_out_to_outer_in_stmts, inner_output_names, strict=True
+            )
+        ]
     )
 
     scan_op_src = f"""

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
@@ -421,7 +421,7 @@ def block_diag(*arrs):
         out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
 
         r, c = 0, 0
-        for arr, shape in zip(arrs, shapes):
+        for arr, shape in zip(arrs, shapes):  # noqa: B905
             rr, cc = shape
             out[r : r + rr, c : c + cc] = arr
             r += rr

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ def advancedincsubtensor1_inplace(x, val, idxs):
         def advancedincsubtensor1_inplace(x, vals, idxs):
             if not len(idxs) == len(vals):
                 raise ValueError("The number of indices and values must match.")
-            for idx, val in zip(idxs, vals):
+            for idx, val in zip(idxs, vals):  # noqa: B905
                 x[idx] = val
             return x
     else:
@@ -184,7 +184,7 @@ def advancedincsubtensor1_inplace(x, val, idxs):
         def advancedincsubtensor1_inplace(x, vals, idxs):
             if not len(idxs) == len(vals):
                 raise ValueError("The number of indices and values must match.")
-            for idx, val in zip(idxs, vals):
+            for idx, val in zip(idxs, vals):  # noqa: B905
                 x[idx] += val
             return x

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 6 additions & 2 deletions
@@ -37,7 +37,9 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
         "\n".join(
             [
                 f"{item_name} = to_scalar({shape_name})"
-                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
+                for item_name, shape_name in zip(
+                    shape_var_item_names, shape_var_names, strict=True
+                )
             ]
         ),
         " " * 4,
@@ -71,7 +73,9 @@ def numba_funcify_Alloc(op, node, **kwargs):
         "\n".join(
             [
                 f"{item_name} = to_scalar({shape_name})"
-                for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
+                for item_name, shape_name in zip(
+                    shape_var_item_names, shape_var_names, strict=True
+                )
             ]
         ),
         " " * 4,

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 22 additions & 13 deletions
@@ -45,7 +45,7 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
     store_outputs = "\n".join(
         [
             f"{output}[...] = {inner_output}"
-            for output, inner_output in zip(outputs, inner_outputs)
+            for output, inner_output in zip(outputs, inner_outputs, strict=True)
         ]
     )
     func_src = f"""
@@ -139,7 +139,7 @@ def _vectorized(
     )
 
     core_input_types = []
-    for input_type, bc_pattern in zip(input_types, input_bc_patterns):
+    for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
         core_ndim = input_type.ndim - len(bc_pattern)
         # TODO: Reconsider this
         if core_ndim == 0:
@@ -152,14 +152,18 @@ def _vectorized(
     core_out_types = [
         types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C")
-        for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types)
+        for dtype, output_core_shape in zip(
+            output_dtypes, output_core_shape_types, strict=True
+        )
     ]
 
     out_types = [
         types.Array(
             numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C"
         )
-        for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types)
+        for dtype, output_core_shape in zip(
+            output_dtypes, output_core_shape_types, strict=True
+        )
     ]
 
     for output_idx, input_idx in inplace_pattern:
@@ -213,7 +217,7 @@ def codegen(
     inputs = [
         arrayobj.make_array(ty)(ctx, builder, val)
-        for ty, val in zip(input_types, inputs)
+        for ty, val in zip(input_types, inputs, strict=True)
     ]
     in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
 
@@ -285,7 +289,9 @@ def compute_itershape(
     if size is not None:
         shape = size
         for i in range(batch_ndim):
-            for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
+            for j, (bc, in_shape) in enumerate(
+                zip(broadcast_pattern, in_shapes, strict=True)
+            ):
                 length = in_shape[i]
                 if bc[i]:
                     with builder.if_then(
@@ -320,7 +326,9 @@ def compute_itershape(
     else:
         # Size is implied by the broadcast pattern
         for i in range(batch_ndim):
-            for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
+            for j, (bc, in_shape) in enumerate(
+                zip(broadcast_pattern, in_shapes, strict=True)
+            ):
                 length = in_shape[i]
                 if bc[i]:
                     with builder.if_then(
@@ -376,7 +384,7 @@ def make_outputs(
     one = ir.IntType(64)(1)
     inplace_dict = dict(inplace)
     for i, (core_shape, bc, dtype) in enumerate(
-        zip(output_core_shapes, out_bc, dtypes)
+        zip(output_core_shapes, out_bc, dtypes, strict=True)
     ):
         if i in inplace_dict:
             output_arrays.append(inputs[inplace_dict[i]])
@@ -390,7 +398,8 @@ def make_outputs(
         # This is actually an internal numba function, I guess we could
         # call `numba.nd.unsafe.ndarray` instead?
         batch_shape = [
-            length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
+            length if not bc_dim else one
+            for length, bc_dim in zip(iter_shape, bc, strict=True)
         ]
         shape = batch_shape + core_shape
         array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
@@ -460,10 +469,10 @@ def make_loop_call(
     # Load values from input arrays
     input_vals = []
-    for input, input_type, bc in zip(inputs, input_types, input_bc):
+    for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True):
         core_ndim = input_type.ndim - len(bc)
 
-        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [
+        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
             zero
         ] * core_ndim
         ptr = cgutils.get_item_pointer2(
@@ -508,13 +517,13 @@ def make_loop_call(
     # Create output slices to pass to inner func
     output_slices = []
-    for output, output_type, bc in zip(outputs, output_types, output_bc):
+    for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
         core_ndim = output_type.ndim - len(bc)
         size_type = output.shape.type.element  # type: ignore
         output_shape = cgutils.unpack_tuple(builder, output.shape)  # type: ignore
         output_strides = cgutils.unpack_tuple(builder, output.strides)  # type: ignore
 
-        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [
+        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
             zero
         ] * core_ndim
         ptr = cgutils.get_item_pointer2(
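
Note (added context): unlike the # noqa: B905 cases, the zips in vectorize_codegen.py run at build time in ordinary Python while the kernel's source and LLVM IR are being assembled, so strict=True is available and costs nothing in the generated code. A minimal sketch of that build-time pattern (store_core_outputs_src is illustrative, not the module's API):

    def store_core_outputs_src(outputs, inner_outputs):
        # Runs in plain Python during codegen, so strict zip applies normally.
        return "\n".join(
            f"{o}[...] = {i}" for o, i in zip(outputs, inner_outputs, strict=True)
        )

    print(store_core_outputs_src(["out0"], ["val0"]))  # out0[...] = val0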
