Skip to content

Commit 2ee4772

Browse files
Update on "[ExecuTorch][#10447] Extend PyBundledModule with extension.BundledModule"
#10447 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. # Proposal Now that we have `extension.BundledModule` ready, we want to test it out by having our existing `PyBundledModule` to extend it, and let `verify_result_with_bundled_expected_output` to use it, so that we can test out the whole thing with https://github.com/pytorch/executorch/blob/fb45e19055a92d2a91a4d4b7008e135232cbb14b/devtools/bundled_program/test/test_end2end.py Differential Revision: [D73564127](https://our.internmc.facebook.com/intern/diff/D73564127/) [ghstack-poisoned]
2 parents 5d0c373 + 3526512 commit 2ee4772

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

extension/module/test/resources/gen_bundled_program.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2-
31
import torch
42

5-
from executorch.exir import to_edge_transform_and_lower
63
from executorch.devtools import BundledProgram
74

85
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
96
from executorch.devtools.bundled_program.serialize import (
107
serialize_from_bundled_program_to_flatbuffer,
118
)
9+
10+
from executorch.exir import to_edge_transform_and_lower
1211
from torch.export import export, export_for_training
1312

1413
# Step 1: ExecuTorch Program Export
@@ -17,8 +16,8 @@ class SampleModel(torch.nn.Module):
1716

1817
def __init__(self) -> None:
1918
super().__init__()
20-
self.register_buffer('a', 3 * torch.ones(2, 2, dtype=torch.int32))
21-
self.register_buffer('b', 2 * torch.ones(2, 2, dtype=torch.int32))
19+
self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
20+
self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
2221

2322
def forward(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
2423
z = x.clone()
@@ -76,7 +75,7 @@ def main() -> None:
7675
test_cases=[
7776
MethodTestCase(
7877
inputs=input,
79-
expected_outputs=(getattr(model, method_name)(*input), ),
78+
expected_outputs=(getattr(model, method_name)(*input),),
8079
)
8180
for input in inputs
8281
],

0 commit comments

Comments
 (0)