Skip to content

Commit 10f285a

Browse files
committed
Use generators when appropriate
1 parent 8ae2a19 commit 10f285a

31 files changed

+96
-134
lines changed

pytensor/configparser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def get_config_hash(self):
104104
)
105105
return hash_from_code(
106106
"\n".join(
107-
[f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts]
107+
f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts
108108
)
109109
)
110110

pytensor/d3viz/formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
360360
for k, v in d.items():
361361
if v is not None:
362362
if isinstance(v, list):
363-
v = "\t".join([str(x) for x in v])
363+
v = "\t".join(str(x) for x in v)
364364
else:
365365
v = str(v)
366366
v = str(v)

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,7 @@ def __str__(self):
12641264
return getattr(
12651265
self,
12661266
"__name__",
1267-
f"{type(self).__name__}({','.join([str(o) for o in self.rewrites])})",
1267+
f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})",
12681268
)
12691269

12701270
def tracks(self):
@@ -1666,7 +1666,7 @@ def pattern_to_str(pattern):
16661666
if isinstance(pattern, list | tuple):
16671667
return "{}({})".format(
16681668
str(pattern[0]),
1669-
", ".join([pattern_to_str(p) for p in pattern[1:]]),
1669+
", ".join(pattern_to_str(p) for p in pattern[1:]),
16701670
)
16711671
elif isinstance(pattern, dict):
16721672
return "{} subject to {}".format(
@@ -2569,7 +2569,7 @@ def print_profile(cls, stream, prof, level=0):
25692569
d = sorted(
25702570
loop_process_count[i].items(), key=lambda a: a[1], reverse=True
25712571
)
2572-
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
2572+
loop_times = " ".join(str((str(k), v)) for k, v in d[:5])
25732573
if len(d) > 5:
25742574
loop_times += " ..."
25752575
print(

pytensor/link/c/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
235235
behavior = code_gen(blocks)
236236

237237
# declares the storage
238-
storage_decl = "\n".join([f"PyObject* {arg};" for arg in args])
238+
storage_decl = "\n".join(f"PyObject* {arg};" for arg in args)
239239
# in the constructor, sets the storage to the arguments
240-
storage_set = "\n".join([f"this->{arg} = {arg};" for arg in args])
240+
storage_set = "\n".join(f"this->{arg} = {arg};" for arg in args)
241241
# increments the storage's refcount in the constructor
242-
storage_incref = "\n".join([f"Py_XINCREF({arg});" for arg in args])
242+
storage_incref = "\n".join(f"Py_XINCREF({arg});" for arg in args)
243243
# decrements the storage's refcount in the destructor
244-
storage_decref = "\n".join([f"Py_XDECREF(this->{arg});" for arg in args])
244+
storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)
245245

246246
args_names = ", ".join(args)
247-
args_decl = ", ".join([f"PyObject* {arg}" for arg in args])
247+
args_decl = ", ".join(f"PyObject* {arg}" for arg in args)
248248

249249
# The following code stores the exception data in __ERROR, which
250250
# is a special field of the struct. __ERROR is a list of length 3

pytensor/link/c/cmodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
20032003
cflags = list(flags)
20042004
# to support path that includes spaces, we need to wrap it with double quotes on Windows
20052005
path_wrapper = '"' if os.name == "nt" else ""
2006-
cflags.extend([f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs()])
2006+
cflags.extend(f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs())
20072007

20082008
res = GCC_compiler.try_compile_tmp(
20092009
test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True
@@ -2573,8 +2573,8 @@ def compile_str(
25732573
cmd.extend(preargs)
25742574
# to support path that includes spaces, we need to wrap it with double quotes on Windows
25752575
path_wrapper = '"' if os.name == "nt" else ""
2576-
cmd.extend([f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs])
2577-
cmd.extend([f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs])
2576+
cmd.extend(f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs)
2577+
cmd.extend(f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs)
25782578
if hide_symbols and sys.platform != "win32":
25792579
# This has been available since gcc 4.0 so we suppose it
25802580
# is always available. We pass it here since it

pytensor/link/c/params_type.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,7 @@ def __init__(self, params_type, **kwargs):
263263

264264
def __repr__(self):
265265
return "Params({})".format(
266-
", ".join(
267-
[(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
268-
)
266+
", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self))
269267
)
270268

271269
def __getattr__(self, key):
@@ -425,9 +423,7 @@ def __getattr__(self, key):
425423

426424
def __repr__(self):
427425
return "ParamsType<{}>".format(
428-
", ".join(
429-
[(f"{self.fields[i]}:{self.types[i]}") for i in range(self.length)]
430-
)
426+
", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length))
431427
)
432428

433429
def __eq__(self, other):
@@ -748,10 +744,8 @@ def c_support_code(self, **kwargs):
748744
}}
749745
""".format(
750746
"\n".join(
751-
[
752-
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
753-
for i in range(self.length)
754-
]
747+
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
748+
for i in range(self.length)
755749
)
756750
)
757751
final_struct_code = """

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
485485
nout = len(node.outputs)
486486
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
487487

488-
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
489-
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
488+
input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
489+
output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
490490
output_dtypes = tuple(out.type.dtype for out in node.outputs)
491491
inplace_pattern = tuple(op.inplace_pattern.items())
492492
core_output_shapes = tuple(() for _ in range(nout))

pytensor/link/numba/dispatch/scalar.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
8585
unique_names = unique_name_generator(
8686
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
8787
)
88-
input_names = ", ".join(
89-
[unique_names(v, force_unique=True) for v in node.inputs]
90-
)
88+
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
9189
if not has_pyx_skip_dispatch:
9290
scalar_op_src = f"""
9391
def {scalar_op_fn_name}({input_names}):
@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):
115113

116114
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
117115
converted_call_args = ", ".join(
118-
[
119-
f"direct_cast({i_name}, {i_tmp_dtype_name})"
120-
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
121-
]
116+
f"direct_cast({i_name}, {i_tmp_dtype_name})"
117+
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
122118
)
123119
if not has_pyx_skip_dispatch:
124120
scalar_op_src = f"""

pytensor/link/numba/dispatch/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def add_output_storage_post_proc_stmt(
373373
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
374374

375375
inner_out_to_outer_out_stmts = "\n".join(
376-
[f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
376+
f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)
377377
)
378378

379379
scan_op_src = f"""

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
3535
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
3636
shapes_to_items_src = indent(
3737
"\n".join(
38-
[
39-
f"{item_name} = to_scalar({shape_name})"
40-
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
41-
]
38+
f"{item_name} = to_scalar({shape_name})"
39+
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
4240
),
4341
" " * 4,
4442
)
@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
6967
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
7068
shapes_to_items_src = indent(
7169
"\n".join(
72-
[
73-
f"{item_name} = to_scalar({shape_name})"
74-
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
75-
]
70+
f"{item_name} = to_scalar({shape_name})"
71+
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
7672
),
7773
" " * 4,
7874
)

0 commit comments

Comments
 (0)