Skip to content

Commit 2a71451

Browse files
Update on "[ExecuTorch][#10375] Add extension.BundledModule to Wrap extension.Module with Bundled Program Logic"
#10375 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. The issue is that `pybindings.PyModule` is dependent on the `method` getter from `pybindings.Module`, which `extension.Module` do not have. Since we don't want to expose `method` getter in `extension.Module`, we have to protect the getter, wrap the functions that is dependent on it and use the protected getter there, ultimately decouple `pybindings` from a `method` getter. # Proposal Now that we have a protected `method` getter, we can introduce a `extension.BundledModule`, a child class inheriting `extension.Module` which wraps up bundled program logic that is dependent on the `method` getter. Differential Revision: [D73564125](https://our.internmc.facebook.com/intern/diff/D73564125/) [ghstack-poisoned]
2 parents ec2c131 + ecd05de commit 2a71451

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

extension/module/test/resources/gen_bundled_program.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2-
31
import torch
42

5-
from executorch.exir import to_edge_transform_and_lower
63
from executorch.devtools import BundledProgram
74

85
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
96
from executorch.devtools.bundled_program.serialize import (
107
serialize_from_bundled_program_to_flatbuffer,
118
)
9+
10+
from executorch.exir import to_edge_transform_and_lower
1211
from torch.export import export, export_for_training
1312

1413
# Step 1: ExecuTorch Program Export
@@ -17,8 +16,8 @@ class SampleModel(torch.nn.Module):
1716

1817
def __init__(self) -> None:
1918
super().__init__()
20-
self.register_buffer('a', 3 * torch.ones(2, 2, dtype=torch.int32))
21-
self.register_buffer('b', 2 * torch.ones(2, 2, dtype=torch.int32))
19+
self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
20+
self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
2221

2322
def forward(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
2423
z = x.clone()
@@ -76,7 +75,7 @@ def main() -> None:
7675
test_cases=[
7776
MethodTestCase(
7877
inputs=input,
79-
expected_outputs=(getattr(model, method_name)(*input), ),
78+
expected_outputs=(getattr(model, method_name)(*input),),
8079
)
8180
for input in inputs
8281
],

0 commit comments

Comments
 (0)