Skip to content

Commit b0f5f6c

Browse files
committed
Remove uses of unique_name_generator in numba dispatch
It's more readable and avoids potential bugs when force_unique is not set to True
1 parent bd2929a commit b0f5f6c

File tree

2 files changed

+18
-50
lines changed

2 files changed

+18
-50
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
1515
from pytensor.link.utils import (
1616
get_name_for_object,
17-
unique_name_generator,
1817
)
1918
from pytensor.scalar.basic import (
2019
Add,
@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
8180
scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
8281

8382
scalar_op_fn_name = get_name_for_object(scalar_func_numba)
84-
83+
prefix = "x" if scalar_func_name != "x" else "y"
84+
input_names = [f"{prefix}{i}" for i in range(len(node.inputs))]
85+
input_signature = ", ".join(input_names)
8586
global_env = {"scalar_func_numba": scalar_func_numba}
8687

8788
if input_inner_dtypes is None and output_inner_dtype is None:
88-
unique_names = unique_name_generator(
89-
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
90-
)
91-
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
9289
if not has_pyx_skip_dispatch:
9390
scalar_op_src = f"""
94-
def {scalar_op_fn_name}({input_names}):
95-
return scalar_func_numba({input_names})
91+
def {scalar_op_fn_name}({input_signature}):
92+
return scalar_func_numba({input_signature})
9693
"""
9794
else:
9895
scalar_op_src = f"""
99-
def {scalar_op_fn_name}({input_names}):
100-
return scalar_func_numba({input_names}, np.intc(1))
96+
def {scalar_op_fn_name}({input_signature}):
97+
return scalar_func_numba({input_signature}, np.intc(1))
10198
"""
10299

103100
else:
@@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}):
108105
for i, i_dtype in enumerate(input_inner_dtypes)
109106
}
110107
global_env.update(input_tmp_dtype_names)
111-
112-
unique_names = unique_name_generator(
113-
[scalar_op_fn_name, "scalar_func_numba", *global_env.keys()],
114-
suffix_sep="_",
115-
)
116-
117-
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
118108
converted_call_args = ", ".join(
119109
f"direct_cast({i_name}, {i_tmp_dtype_name})"
120110
for i_name, i_tmp_dtype_name in zip(
@@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}):
123113
)
124114
if not has_pyx_skip_dispatch:
125115
scalar_op_src = f"""
126-
def {scalar_op_fn_name}({", ".join(input_names)}):
116+
def {scalar_op_fn_name}({input_signature}):
127117
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
128118
"""
129119
else:
130120
scalar_op_src = f"""
131-
def {scalar_op_fn_name}({", ".join(input_names)}):
121+
def {scalar_op_fn_name}({input_signature}):
132122
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
133123
"""
134124

135125
scalar_op_fn = compile_numba_function_src(
136126
scalar_op_src,
137127
scalar_op_fn_name,
138-
{**globals(), **global_env},
128+
globals() | global_env,
139129
)
140130

141131
# Functions that call a function pointer can't be cached
@@ -157,8 +147,8 @@ def switch(condition, x, y):
157147

158148
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
159149
"""Create a Numba-compatible N-ary function from a binary function."""
160-
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
161-
input_names = [unique_names(v, force_unique=True) for v in inputs]
150+
var_prefix = "x" if binary_op_name != "x" else "y"
151+
input_names = [f"{var_prefix}{i}" for i in range(len(inputs))]
162152
input_signature = ", ".join(input_names)
163153
output_expr = binary_op.join(input_names)
164154

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
register_funcify_and_cache_key,
1111
register_funcify_default_op_cache_key,
1212
)
13-
from pytensor.link.utils import unique_name_generator
1413
from pytensor.tensor.basic import (
1514
Alloc,
1615
AllocEmpty,
@@ -28,15 +27,7 @@
2827

2928
@register_funcify_default_op_cache_key(AllocEmpty)
3029
def numba_funcify_AllocEmpty(op, node, **kwargs):
31-
global_env = {
32-
"np": np,
33-
"dtype": np.dtype(op.dtype),
34-
}
35-
36-
unique_names = unique_name_generator(
37-
["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
38-
)
39-
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
30+
shape_var_names = [f"sh{i}" for i in range(len(node.inputs))]
4031
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
4132
shapes_to_items_src = indent(
4233
"\n".join(
@@ -56,21 +47,15 @@ def allocempty({", ".join(shape_var_names)}):
5647
"""
5748

5849
alloc_fn = compile_numba_function_src(
59-
alloc_def_src, "allocempty", {**globals(), **global_env}
50+
alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)}
6051
)
6152

6253
return numba_basic.numba_njit(alloc_fn)
6354

6455

6556
@register_funcify_and_cache_key(Alloc)
6657
def numba_funcify_Alloc(op, node, **kwargs):
67-
global_env = {"np": np}
68-
69-
unique_names = unique_name_generator(
70-
["np", "alloc", "val_np", "val", "scalar_shape", "res"],
71-
suffix_sep="_",
72-
)
73-
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
58+
shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)]
7459
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
7560
shapes_to_items_src = indent(
7661
"\n".join(
@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}):
10287
alloc_fn = compile_numba_function_src(
10388
alloc_def_src,
10489
"alloc",
105-
{**globals(), **global_env},
90+
globals() | {"np": np},
10691
)
10792

10893
cache_key = sha256(
@@ -207,14 +192,7 @@ def eye(N, M, k):
207192
@register_funcify_default_op_cache_key(MakeVector)
208193
def numba_funcify_MakeVector(op, node, **kwargs):
209194
dtype = np.dtype(op.dtype)
210-
211-
global_env = {"np": np, "dtype": dtype}
212-
213-
unique_names = unique_name_generator(
214-
["np"],
215-
suffix_sep="_",
216-
)
217-
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
195+
input_names = [f"x{i}" for i in range(len(node.inputs))]
218196

219197
def create_list_string(x):
220198
args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}):
228206
makevector_fn = compile_numba_function_src(
229207
makevector_def_src,
230208
"makevector",
231-
{**globals(), **global_env},
209+
globals() | {"np": np, "dtype": dtype},
232210
)
233211

234212
return numba_basic.numba_njit(makevector_fn)

0 commit comments

Comments
 (0)