1818from torch .export .experimental import _export_forward_backward
1919
2020
21+ def _export_model ():
22+ net = TrainingNet (Net ())
23+ x = torch .randn (1 , 2 )
24+
25+ # Captures the forward graph. The graph will look similar to the model definition now.
26+ # Will move to export_for_training soon which is the api planned to be supported in the long term.
27+ ep = export (net , (x , torch .ones (1 , dtype = torch .int64 )))
28+ # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
29+ ep = _export_forward_backward (ep )
30+ # Lower the graph to edge dialect.
31+ ep = to_edge (ep )
32+ # Lower the graph to executorch.
33+ ep = ep .to_executorch ()
34+
35+
2136def main () -> None :
2237 torch .manual_seed (0 )
2338 parser = argparse .ArgumentParser (
@@ -32,18 +47,7 @@ def main() -> None:
3247 )
3348 args = parser .parse_args ()
3449
35- net = TrainingNet (Net ())
36- x = torch .randn (1 , 2 )
37-
38- # Captures the forward graph. The graph will look similar to the model definition now.
39- # Will move to export_for_training soon which is the api planned to be supported in the long term.
40- ep = export (net , (x , torch .ones (1 , dtype = torch .int64 )))
41- # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42- ep = _export_forward_backward (ep )
43- # Lower the graph to edge dialect.
44- ep = to_edge (ep )
45- # Lower the graph to executorch.
46- ep = ep .to_executorch ()
50+ ep = _export_model ()
4751
4852 # Write out the .pte file.
4953 os .makedirs (args .outdir , exist_ok = True )
0 commit comments