|
12 | 12 | from torch import Tensor |
13 | 13 | from torch._export.verifier import Verifier |
14 | 14 | from torch.export import ExportedProgram |
| 15 | +from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature |
15 | 16 | from torch.export.graph_signature import ( |
16 | 17 | ExportGraphSignature, |
17 | 18 | InputKind, |
|
20 | 21 | OutputSpec, |
21 | 22 | TensorArgument, |
22 | 23 | ) |
| 24 | +from torch.utils import _pytree as pytree |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class IrMode(Enum): |
@@ -87,17 +89,25 @@ def get_verifiers(self) -> Optional[list[Verifier]]: |
87 | 89 |
|
88 | 90 | def get_program(self) -> ExportedProgram: |
89 | 91 | gm = self.get_graph_module() |
| 92 | + graph_signature = ExportGraphSignature(self.input_specs, self.output_specs) |
| 93 | + in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1] |
| 94 | + out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1] |
90 | 95 | return ExportedProgram( |
91 | 96 | root=gm, |
92 | 97 | graph=gm.graph, |
93 | | - graph_signature=ExportGraphSignature( |
94 | | - input_specs=self.input_specs, output_specs=self.output_specs |
95 | | - ), |
| 98 | + graph_signature=graph_signature, |
96 | 99 | # pyre-ignore[6]: Incompatible parameter type. |
97 | 100 | constants=self.constants, |
98 | 101 | state_dict=self.state_dict, |
99 | 102 | range_constraints={}, |
100 | | - module_call_graph=[], |
| 103 | + module_call_graph=[ |
| 104 | + ModuleCallEntry( |
| 105 | + "", |
| 106 | + ModuleCallSignature( |
| 107 | + inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec |
| 108 | + ), |
| 109 | + ) |
| 110 | + ], |
101 | 111 | # pyre-ignore[6]: Incompatible parameter type. |
102 | 112 | verifiers=self.get_verifiers(), |
103 | 113 | ) |
|
0 commit comments