Skip to content

Commit dc908cb

Browse files
author
ischweer
committed
Fix dangling reference
1 parent 96ec531 commit dc908cb

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,13 @@ def pytorch_funcify_FunctionGraph(
6161
conversion_func=pytorch_funcify,
6262
**kwargs,
6363
):
64-
def constants_wrapper(x, **kwargs):
65-
x = pytorch_typify(x)
66-
67-
@torch.compiler.assume_constant_result
68-
def torch_assume_constant(arg=x):
69-
return arg
70-
71-
return torch_assume_constant
64+
if "type_conversion_fn" not in kwargs:
65+
kwargs["type_conversion_fn"] = pytorch_typify
7266

7367
built_kwargs = {"conversion_func": conversion_func, **kwargs}
7468
return fgraph_to_python(
7569
fgraph,
7670
conversion_func,
77-
type_conversion_fn=constants_wrapper,
7871
fgraph_name=fgraph_name,
7972
**built_kwargs,
8073
)

pytensor/link/pytorch/linker.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def __init__(self, *args, **kwargs):
1010
self.gen_functors = []
1111

1212
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
13-
from pytensor.link.pytorch.dispatch import pytorch_funcify
13+
import torch
14+
15+
from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify
1416

1517
# We want to have globally unique names
1618
# across the entire pytensor graph, not
@@ -25,9 +27,21 @@ def conversion_func_register(*args, **kwargs):
2527
self.gen_functors.append((f"_{name}", functor))
2628
return functor
2729

30+
def constants_wrapper(x, **kwargs):
31+
x = pytorch_typify(x)
32+
33+
@torch.compiler.assume_constant_result
34+
def torch_assume_constant(arg=x):
35+
return arg
36+
37+
name = kwargs["unique_name"](torch_assume_constant)
38+
self.gen_functors.append((f"_{name}", torch_assume_constant))
39+
return torch_assume_constant
40+
2841
built_kwargs = {
2942
"unique_name": generator,
3043
"conversion_func": conversion_func_register,
44+
"type_conversion_fn": constants_wrapper,
3145
**kwargs,
3246
}
3347
return pytorch_funcify(

pytensor/link/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,8 +766,8 @@ def fgraph_to_python(
766766
global_env[local_input_name] = type_conversion_fn(
767767
input_storage[0], variable=i, storage=input_storage, **kwargs
768768
)
769-
# TODO: We could attempt to use the storage arrays directly
770-
# E.g. `local_input_name = f"{local_input_name}[0]"`
769+
# TODO: We could attempt to use the storage arrays directly
770+
# E.g. `local_input_name = f"{local_input_name}[0]"`
771771
node_input_names.append(local_input_name)
772772

773773
node_output_names = [unique_name(v) for v in node.outputs]

0 commit comments

Comments
 (0)