Skip to content

Commit 05cbafc

Browse files
author
Ian Schweer
committed
Share one name generator
1 parent d26798f commit 05cbafc

File tree

3 files changed

+9
-12
lines changed

3 files changed

+9
-12
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import inspect
22
import sys
3-
from collections import defaultdict
43
from functools import singledispatch
54
from types import NoneType
65

@@ -59,19 +58,12 @@ def pytorch_funcify_FunctionGraph(
5958
fgraph_name="pytorch_funcified_fgraph",
6059
**kwargs,
6160
):
62-
fnames_counter = defaultdict(int)
63-
6461
def conversion_func_register(*args, **kwargs):
6562
last_frame = inspect.currentframe().f_back
6663
functor = pytorch_funcify(*args, **kwargs)
6764
if last_frame:
6865
module = sys.modules[last_frame.f_globals["__name__"]]
69-
name = functor.__name__
70-
if fnames_counter[name]:
71-
name = f"{name}_{fnames_counter[name]}"
72-
print("Registering functor", name)
73-
setattr(module, name, functor)
74-
fnames_counter[functor.__name__] += 1
66+
setattr(module, kwargs["unique_name"](functor), functor)
7567
return functor
7668

7769
return fgraph_to_python(
@@ -191,7 +183,6 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
191183

192184
# Apply inner rewrites
193185
PYTORCH.optimizer(op.fgraph)
194-
195186
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
196187
return fgraph_fn
197188

pytensor/link/pytorch/linker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pytensor.graph.basic import Variable
44
from pytensor.link.basic import JITLinker
5+
from pytensor.link.utils import unique_name_generator
56

67

78
class PytorchLinker(JITLinker):
@@ -18,8 +19,10 @@ def output_filter(self, var: Variable, out: Any) -> Any:
1819
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1920
from pytensor.link.pytorch.dispatch import pytorch_funcify
2021

22+
generator = unique_name_generator(["torch_linker"])
23+
built_kwargs = {"unique_name": generator, **kwargs}
2124
return pytorch_funcify(
22-
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
25+
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
2326
)
2427

2528
def jit_compile(self, fn):

pytensor/link/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,8 @@ def fgraph_to_python(
704704
get_name_for_object
705705
A function used to provide names for the objects referenced within the
706706
generated function.
707+
unique_name
708+
A function to make random function names for generated code (access through kwargs)
707709
squeeze_output
708710
If the `FunctionGraph` has only one output and this option is
709711
``True``, return the single output instead of a tuple with the output.
@@ -717,7 +719,8 @@ def fgraph_to_python(
717719
if storage_map is None:
718720
storage_map = {}
719721

720-
unique_name = unique_name_generator([fgraph_name])
722+
unique_name = kwargs.get("unique_name", unique_name_generator([fgraph_name]))
723+
kwargs["unique_name"] = unique_name
721724

722725
if global_env is None:
723726
global_env = {}

0 commit comments

Comments
 (0)