Skip to content

Commit 6c12956

Browse files
authored
Add program-data separation to pybindings
Differential Revision: D76353209 Pull Request resolved: #13886
1 parent fbda3a9 commit 6c12956

File tree

4 files changed

+95
-16
lines changed

4 files changed

+95
-16
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,42 @@ inline std::unique_ptr<Module> load_module_from_buffer(
174174
}
175175

176176
inline std::unique_ptr<Module> load_module_from_file(
177-
const std::string& path,
177+
const std::string& program_path,
178+
std::optional<const std::string>& data_map_path,
178179
std::unique_ptr<runtime::EventTracer> event_tracer,
179180
Program::Verification program_verification) {
180181
EXECUTORCH_SCOPE_PROF("load_module_from_file");
181182

182-
Result<MmapDataLoader> res = MmapDataLoader::from(
183-
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
183+
Result<MmapDataLoader> program_loader_res = MmapDataLoader::from(
184+
program_path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
184185
THROW_IF_ERROR(
185-
res.error(),
186+
program_loader_res.error(),
186187
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
187-
path.c_str(),
188-
static_cast<uint32_t>(res.error()));
189-
190-
auto loader = std::make_unique<MmapDataLoader>(std::move(res.get()));
188+
program_path.c_str(),
189+
static_cast<uint32_t>(program_loader_res.error()));
190+
auto program_loader =
191+
std::make_unique<MmapDataLoader>(std::move(program_loader_res.get()));
192+
193+
if (data_map_path.has_value()) {
194+
Result<MmapDataLoader> data_map_loader_res = MmapDataLoader::from(
195+
data_map_path->c_str(),
196+
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
197+
THROW_IF_ERROR(
198+
data_map_loader_res.error(),
199+
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
200+
data_map_path->c_str(),
201+
static_cast<uint32_t>(data_map_loader_res.error()));
202+
auto data_map_loader =
203+
std::make_unique<MmapDataLoader>(std::move(data_map_loader_res.get()));
204+
return std::make_unique<Module>(
205+
std::move(program_loader),
206+
nullptr, // memory_allocator
207+
nullptr, // temp_allocator
208+
std::move(event_tracer), // event_tracer
209+
std::move(data_map_loader)); // data_map_loader
210+
}
191211
return std::make_unique<Module>(
192-
std::move(loader),
212+
std::move(program_loader),
193213
nullptr, // memory_allocator
194214
nullptr, // temp_allocator
195215
std::move(event_tracer), // event_tracer
@@ -510,14 +530,16 @@ struct PyModule final {
510530
program_verification)) {}
511531

512532
explicit PyModule(
513-
const std::string& path,
533+
const std::string& program_path,
534+
std::optional<const std::string>& data_path,
514535
bool enable_etdump,
515536
size_t debug_buffer_size = 0,
516537
Program::Verification program_verification =
517538
Program::Verification::InternalConsistency)
518539
: debug_buffer_size_(debug_buffer_size),
519540
module_(load_module_from_file(
520-
path,
541+
program_path,
542+
data_path,
521543
setup_event_tracer(enable_etdump, debug_buffer_size),
522544
program_verification)) {}
523545

@@ -536,14 +558,20 @@ struct PyModule final {
536558
return std::make_unique<PyModule>(
537559
buffer, enable_etdump, debug_buffer_size, program_verification);
538560
}
561+
539562
static std::unique_ptr<PyModule> load_from_file(
540-
const std::string& path,
563+
const std::string& program_path,
564+
std::optional<const std::string>& data_path,
541565
bool enable_etdump,
542566
size_t debug_buffer_size = 0,
543567
Program::Verification program_verification =
544568
Program::Verification::InternalConsistency) {
545569
return std::make_unique<PyModule>(
546-
path, enable_etdump, debug_buffer_size, program_verification);
570+
program_path,
571+
data_path,
572+
enable_etdump,
573+
debug_buffer_size,
574+
program_verification);
547575
}
548576

549577
static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -1351,7 +1379,8 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
13511379
m.def(
13521380
"_load_for_executorch",
13531381
PyModule::load_from_file,
1354-
py::arg("path"),
1382+
py::arg("program_path"),
1383+
py::arg("data_path") = std::nullopt,
13551384
py::arg("enable_etdump") = false,
13561385
py::arg("debug_buffer_size") = 0,
13571386
py::arg("program_verification") =

extension/pybindings/pybindings.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ class MethodMeta:
156156

157157
@experimental("This API is experimental and subject to change without notice.")
158158
def _load_for_executorch(
159-
path: str,
159+
program_path: str,
160+
data_path: Optional[str] = None,
160161
enable_etdump: bool = False,
161162
debug_buffer_size: int = 0,
162163
program_verification: Verification = Verification.InternalConsistency,
@@ -168,7 +169,8 @@ def _load_for_executorch(
168169
This API is experimental and subject to change without notice.
169170
170171
Args:
171-
path: File path to the ExecuTorch program as a string.
172+
program_path: File path to the ExecuTorch program as a string.
173+
data_path: File path to a .ptd file containing data used by the program.
172174
enable_etdump: If true, enables an ETDump which can store profiling information.
173175
See documentation at https://pytorch.org/executorch/main/etdump
174176
for how to use it.

extension/pybindings/test/make_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,21 @@ def get_inputs(self):
133133
return (torch.ones(2, 2), torch.ones(2, 2))
134134

135135

136+
class ModuleLinear(torch.nn.Module):
137+
def __init__(self):
138+
super().__init__()
139+
self.linear = torch.nn.Linear(3, 3)
140+
141+
def forward(self, x: torch.Tensor):
142+
return self.linear(x)
143+
144+
def get_methods_to_export(self):
145+
return ("forward",)
146+
147+
def get_inputs(self):
148+
return (torch.randn(3),)
149+
150+
136151
def create_program(
137152
eager_module: torch.nn.Module,
138153
et_config: Optional[ExecutorchBackendConfig] = None,

extension/pybindings/test/test_pybindings.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ModuleAddWithAttributes,
2323
ModuleChannelsLast,
2424
ModuleChannelsLastInDefaultOut,
25+
ModuleLinear,
2526
ModuleMulti,
2627
)
2728
from torch.export import export
@@ -623,3 +624,35 @@ def test_method_method_meta(self) -> None:
623624
self.assertEqual(output_tensor.is_memory_planned(), True)
624625
self.assertEqual(output_tensor.nbytes(), 16)
625626
self.assertEqual(str(output_tensor), tensor_info)
627+
628+
def test_program_data_separation(self) -> None:
629+
eager_module = ModuleLinear()
630+
inputs = eager_module.get_inputs()
631+
exported_program = export(eager_module, inputs, strict=True)
632+
exec_program = to_edge(exported_program).to_executorch(
633+
config=ExecutorchBackendConfig(
634+
# Move all tensor data to '_default_external_constant' file.
635+
external_constants=True,
636+
)
637+
)
638+
639+
import os
640+
import tempfile
641+
642+
with tempfile.TemporaryDirectory() as tmpdir:
643+
pte_file = os.path.join(tmpdir, "linear.pte")
644+
with open(pte_file, "wb") as f:
645+
f.write(exec_program.buffer)
646+
647+
ptd_file = os.path.join(tmpdir, "linear.ptd")
648+
with open(ptd_file, "wb") as ptd:
649+
tensor_data = bytes(
650+
exec_program._tensor_data.pop("_default_external_constant")
651+
)
652+
ptd.write(tensor_data)
653+
654+
executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
655+
656+
expected = eager_module(inputs[0])
657+
executorch_output = executorch_program.forward(inputs)[0]
658+
self.assertTrue(torch.allclose(expected, executorch_output))

0 commit comments

Comments
 (0)