Skip to content

Commit bbe8943

Browse files
authored
Update module call graph for export program builder.
Differential Revision: D77341475 Pull Request resolved: #14038
1 parent b90c743 commit bbe8943

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

backends/cadence/aot/program_builder.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch import Tensor
1313
from torch._export.verifier import Verifier
1414
from torch.export import ExportedProgram
15+
from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature
1516
from torch.export.graph_signature import (
1617
ExportGraphSignature,
1718
InputKind,
@@ -20,6 +21,7 @@
2021
OutputSpec,
2122
TensorArgument,
2223
)
24+
from torch.utils import _pytree as pytree
2325

2426

2527
class IrMode(Enum):
@@ -87,17 +89,25 @@ def get_verifiers(self) -> Optional[list[Verifier]]:
8789

8890
def get_program(self) -> ExportedProgram:
8991
gm = self.get_graph_module()
92+
graph_signature = ExportGraphSignature(self.input_specs, self.output_specs)
93+
in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1]
94+
out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1]
9095
return ExportedProgram(
9196
root=gm,
9297
graph=gm.graph,
93-
graph_signature=ExportGraphSignature(
94-
input_specs=self.input_specs, output_specs=self.output_specs
95-
),
98+
graph_signature=graph_signature,
9699
# pyre-ignore[6]: Incompatible parameter type.
97100
constants=self.constants,
98101
state_dict=self.state_dict,
99102
range_constraints={},
100-
module_call_graph=[],
103+
module_call_graph=[
104+
ModuleCallEntry(
105+
"",
106+
ModuleCallSignature(
107+
inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
108+
),
109+
)
110+
],
101111
# pyre-ignore[6]: Incompatible parameter type.
102112
verifiers=self.get_verifiers(),
103113
)

0 commit comments

Comments
 (0)