|
12 | 12 | from torch import Tensor
|
13 | 13 | from torch._export.verifier import Verifier
|
14 | 14 | from torch.export import ExportedProgram
|
| 15 | +from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature |
15 | 16 | from torch.export.graph_signature import (
|
16 | 17 | ExportGraphSignature,
|
17 | 18 | InputKind,
|
|
20 | 21 | OutputSpec,
|
21 | 22 | TensorArgument,
|
22 | 23 | )
|
| 24 | +from torch.utils import _pytree as pytree |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class IrMode(Enum):
|
@@ -87,17 +89,25 @@ def get_verifiers(self) -> Optional[list[Verifier]]:
|
87 | 89 |
|
88 | 90 | def get_program(self) -> ExportedProgram:
|
89 | 91 | gm = self.get_graph_module()
|
| 92 | + graph_signature = ExportGraphSignature(self.input_specs, self.output_specs) |
| 93 | + in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1] |
| 94 | + out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1] |
90 | 95 | return ExportedProgram(
|
91 | 96 | root=gm,
|
92 | 97 | graph=gm.graph,
|
93 |
| - graph_signature=ExportGraphSignature( |
94 |
| - input_specs=self.input_specs, output_specs=self.output_specs |
95 |
| - ), |
| 98 | + graph_signature=graph_signature, |
96 | 99 | # pyre-ignore[6]: Incompatible parameter type.
|
97 | 100 | constants=self.constants,
|
98 | 101 | state_dict=self.state_dict,
|
99 | 102 | range_constraints={},
|
100 |
| - module_call_graph=[], |
| 103 | + module_call_graph=[ |
| 104 | + ModuleCallEntry( |
| 105 | + "", |
| 106 | + ModuleCallSignature( |
| 107 | + inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec |
| 108 | + ), |
| 109 | + ) |
| 110 | + ], |
101 | 111 | # pyre-ignore[6]: Incompatible parameter type.
|
102 | 112 | verifiers=self.get_verifiers(),
|
103 | 113 | )
|
|
0 commit comments