Skip to content

Commit 59cbddc

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Ensure ScalarToTensorPass returns correct modified value.
Differential Revision: D88215978
1 parent a5b090f commit 59cbddc

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

exir/passes/scalar_to_tensor_pass.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,40 @@
77
# pyre-strict
88

99
import torch
10-
from executorch.exir.pass_base import ExportPass, map_args
10+
from executorch.exir.pass_base import ExportPass, PassResult, map_args
1111

1212

1313
class ScalarToTensorPass(ExportPass):
14+
def __init__(self) -> None:
15+
super().__init__()
16+
self._modified = False
17+
1418
# pyre-ignore
1519
def call_operator(self, op, args, kwargs, meta):
1620
# pyre-ignore
1721
def try_coerce(value, arg):
1822
# Note: we want to create tensor constants instead of
19-
# FakeTensor or ProxyTensor. If python_dispatcher is enabled,
23+
# FakeTensor or ProxyTensor. If python_dispatcher is enabled,was
2024
# the fake_tensor_mode of inputs will be used so that we won't
2125
# get a constant tensor with torch.tensor() call but instead
2226
# a fake tensor is created.
2327
with torch.utils._python_dispatch._disable_current_modes():
24-
return (
25-
torch.tensor(value)
26-
if isinstance(value, (float, int, bool))
27-
and isinstance(arg.type, torch.TensorType)
28-
else value
28+
should_coerce = isinstance(value, (float, int, bool)) and isinstance(
29+
arg.type, torch.TensorType
2930
)
31+
if should_coerce:
32+
self._modified = True
33+
return torch.tensor(value)
34+
return value
3035

3136
args, kwargs = map_args(op, try_coerce, args, kwargs)
3237
return super().call_operator(op, args, kwargs, meta)
38+
39+
# pyre-ignore
40+
def call(self, graph_module):
41+
42+
self._modified = False
43+
result = super().call(graph_module)
44+
if result is not None:
45+
return PassResult(result.graph_module, self._modified)
46+
return result

0 commit comments

Comments
 (0)