diff --git a/extension/training/examples/XOR/export_model.py b/extension/training/examples/XOR/export_model.py index 3089cea211e..c2cff7d4284 100644 --- a/extension/training/examples/XOR/export_model.py +++ b/extension/training/examples/XOR/export_model.py @@ -18,6 +18,21 @@ from torch.export.experimental import _export_forward_backward +def _export_model(): + net = TrainingNet(Net()) + x = torch.randn(1, 2) + + # Captures the forward graph. The graph will look similar to the model definition now. + # Will move to export_for_training soon which is the api planned to be supported in the long term. + ep = export(net, (x, torch.ones(1, dtype=torch.int64))) + # Captures the backward graph. The exported_program now contains the joint forward and backward graph. + ep = _export_forward_backward(ep) + # Lower the graph to edge dialect. + ep = to_edge(ep) + # Lower the graph to executorch. + ep = ep.to_executorch() + + def main() -> None: torch.manual_seed(0) parser = argparse.ArgumentParser( @@ -32,18 +47,7 @@ def main() -> None: ) args = parser.parse_args() - net = TrainingNet(Net()) - x = torch.randn(1, 2) - - # Captures the forward graph. The graph will look similar to the model definition now. - # Will move to export_for_training soon which is the api planned to be supported in the long term. - ep = export(net, (x, torch.ones(1, dtype=torch.int64))) - # Captures the backward graph. The exported_program now contains the joint forward and backward graph. - ep = _export_forward_backward(ep) - # Lower the graph to edge dialect. - ep = to_edge(ep) - # Lower the graph to executorch. - ep = ep.to_executorch() + ep = _export_model() # Write out the .pte file. os.makedirs(args.outdir, exist_ok=True) diff --git a/extension/training/examples/XOR/targets.bzl b/extension/training/examples/XOR/targets.bzl index ccd7f4bf6f8..26d0f40d90b 100644 --- a/extension/training/examples/XOR/targets.bzl +++ b/extension/training/examples/XOR/targets.bzl @@ -34,7 +34,7 @@ def define_common_targets(): runtime.python_library( name = "export_model_lib", srcs = ["export_model.py"], - visibility = [], + visibility = ["//executorch/extension/training/examples/XOR/..."], deps = [ ":model", "//caffe2:torch", diff --git a/extension/training/examples/XOR/test/TARGETS b/extension/training/examples/XOR/test/TARGETS new file mode 100644 index 00000000000..5c74b50efd8 --- /dev/null +++ b/extension/training/examples/XOR/test/TARGETS @@ -0,0 +1,15 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_test( + name = "test", + srcs = ["test_export.py"], + visibility = ["//executorch/extension/training/examples/XOR/test/..."], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/extension/training:lib", + "//executorch/extension/training/examples/XOR:export_model_lib", + ], +) diff --git a/extension/training/examples/XOR/test/test_export.py b/extension/training/examples/XOR/test/test_export.py new file mode 100644 index 00000000000..26a24607d9e --- /dev/null +++ b/extension/training/examples/XOR/test/test_export.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +from executorch.extension.training.examples.XOR.export_model import _export_model + + +class TestXORExport(unittest.TestCase): + def test(self): + _ = _export_model() + # Expect that we reach this far without an exception being thrown. + self.assertTrue(True)