Skip to content

Commit cc6cb83

Browse files
authored
Add option to specify fake tensor mode for graph and program builders.
Differential Revision: D84187909 Pull Request resolved: #14958
1 parent 1dc0e0e commit cc6cb83

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

backends/cadence/aot/graph_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ class GraphBuilder(ExportPass):
4444
gm = builder.get_graph_module()
4545
"""
4646

47-
def __init__(self) -> None:
47+
def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None:
4848
self.exporter = ExportPass()
4949
self.tracer: ExportPass.ExportTracer = self.ExportTracer(
5050
self, torch.fx.graph.CodeGen()
5151
)
52-
self.fake_tensor_mode = FakeTensorMode(
52+
self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode(
5353
allow_fallback_kernels=False,
5454
allow_non_fake_inputs=True,
5555
)

backends/cadence/aot/program_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch import Tensor
1313
from torch._export.verifier import Verifier
1414
from torch._ops import OpOverload
15+
from torch._subclasses.fake_tensor import FakeTensorMode
1516
from torch.export import ExportedProgram
1617
from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature
1718
from torch.export.graph_signature import (
@@ -37,6 +38,7 @@ def __init__(
3738
self,
3839
mode: Optional[IrMode] = None,
3940
_core_aten_ops_exception_list: Optional[list[OpOverload]] = None,
41+
fake_tensor_mode: Optional[FakeTensorMode] = None,
4042
) -> None:
4143
self.input_specs: list[InputSpec] = []
4244
self.output_specs: list[OutputSpec] = []
@@ -46,7 +48,7 @@ def __init__(
4648
self._core_aten_ops_exception_list: list[OpOverload] = (
4749
_core_aten_ops_exception_list or []
4850
)
49-
super().__init__()
51+
super().__init__(fake_tensor_mode=fake_tensor_mode)
5052

5153
def insert_input_spec(
5254
self, target: str, input_kind: InputKind, value: Tensor

0 commit comments

Comments
 (0)