|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 | import torch |
10 | | -from executorch.exir.pass_base import ExportPass, map_args |
| 10 | +from executorch.exir.pass_base import ExportPass, PassResult, map_args |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class ScalarToTensorPass(ExportPass): |
| 14 | + def __init__(self) -> None: |
| 15 | + super().__init__() |
| 16 | + self._modified = False |
| 17 | + |
14 | 18 | # pyre-ignore |
15 | 19 | def call_operator(self, op, args, kwargs, meta): |
16 | 20 | # pyre-ignore |
17 | 21 | def try_coerce(value, arg): |
18 | 22 | # Note: we want to create tensor constants instead of |
19 | | - # FakeTensor or ProxyTensor. If python_dispatcher is enabled, |
| 23 | + # FakeTensor or ProxyTensor. If python_dispatcher is enabled,was |
20 | 24 | # the fake_tensor_mode of inputs will be used so that we won't |
21 | 25 | # get a constant tensor with torch.tensor() call but instead |
22 | 26 | # a fake tensor is created. |
23 | 27 | with torch.utils._python_dispatch._disable_current_modes(): |
24 | | - return ( |
25 | | - torch.tensor(value) |
26 | | - if isinstance(value, (float, int, bool)) |
27 | | - and isinstance(arg.type, torch.TensorType) |
28 | | - else value |
| 28 | + should_coerce = isinstance(value, (float, int, bool)) and isinstance( |
| 29 | + arg.type, torch.TensorType |
29 | 30 | ) |
| 31 | + if should_coerce: |
| 32 | + self._modified = True |
| 33 | + return torch.tensor(value) |
| 34 | + return value |
30 | 35 |
|
31 | 36 | args, kwargs = map_args(op, try_coerce, args, kwargs) |
32 | 37 | return super().call_operator(op, args, kwargs, meta) |
| 38 | + |
| 39 | + # pyre-ignore |
| 40 | + def call(self, graph_module): |
| 41 | + |
| 42 | + self._modified = False |
| 43 | + result = super().call(graph_module) |
| 44 | + if result is not None: |
| 45 | + return PassResult(result.graph_module, self._modified) |
| 46 | + return result |
0 commit comments