Commit 5b4cb9a

lucylq authored and facebook-github-bot committed

Add program-data separation to pybindings (#13886)

Summary: Add support for an optional data path in the pybindings, allowing a .ptd data file to be loaded alongside the .pte program.

Differential Revision: D76353209

1 parent 6c1ef96 · commit 5b4cb9a
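
Usage note: program-data separation keeps the program in a .pte file and moves constant tensor data into a companion .ptd file, which the pybindings can now mmap alongside the program. A minimal sketch of the new call; the file names are placeholders, and the import assumes the portable_lib pybindings target:

# Sketch of the API surface added by this commit.
# "model.pte" / "model.ptd" are placeholder paths.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Previously only a single program path could be passed:
# module = _load_for_executorch("model.pte")

# Now an optional data path supplies the external tensor data:
module = _load_for_executorch("model.pte", data_path="model.ptd")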

File tree

4 files changed: +104, -21 lines

extension/pybindings/pybindings.cpp (52 additions, 19 deletions)

@@ -173,27 +173,51 @@ inline std::unique_ptr<Module> load_module_from_buffer(
 }
 
 inline std::unique_ptr<Module> load_module_from_file(
-    const std::string& path,
+    const std::string& program_path,
+    std::optional<const std::string>& data_map_path,
     bool enable_etdump,
     size_t debug_buffer_size,
     Program::Verification program_verification) {
   EXECUTORCH_SCOPE_PROF("load_module_from_file");
 
-  Result<MmapDataLoader> res = MmapDataLoader::from(
-      path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
+  Result<MmapDataLoader> program_loader_res = MmapDataLoader::from(
+      program_path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
   THROW_IF_ERROR(
-      res.error(),
+      program_loader_res.error(),
       "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
-      path.c_str(),
-      static_cast<uint32_t>(res.error()));
-
-  auto loader = std::make_unique<MmapDataLoader>(std::move(res.get()));
-  return std::make_unique<Module>(
-      std::move(loader),
-      nullptr, // memory_allocator
-      nullptr, // temp_allocator
-      enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
-      nullptr); // data_map_loader
+      program_path.c_str(),
+      static_cast<uint32_t>(program_loader_res.error()));
+  auto program_loader =
+      std::make_unique<MmapDataLoader>(std::move(program_loader_res.get()));
+
+  if (data_map_path.has_value()) {
+    Result<MmapDataLoader> data_map_loader_res = MmapDataLoader::from(
+        data_map_path->c_str(),
+        MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
+    THROW_IF_ERROR(
+        data_map_loader_res.error(),
+        "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
+        data_map_path->c_str(),
+        static_cast<uint32_t>(data_map_loader_res.error()));
+    auto data_map_loader =
+        std::make_unique<MmapDataLoader>(std::move(data_map_loader_res.get()));
+
+    return std::make_unique<Module>(
+        std::move(program_loader),
+        nullptr, // memory_allocator
+        nullptr, // temp_allocator
+        enable_etdump ? std::make_unique<torch::executor::ETDumpGen>()
+                      : nullptr,
+        std::move(data_map_loader)); // data_map_loader
+  } else {
+    return std::make_unique<Module>(
+        std::move(program_loader),
+        nullptr, // memory_allocator
+        nullptr, // temp_allocator
+        enable_etdump ? std::make_unique<torch::executor::ETDumpGen>()
+                      : nullptr,
+        nullptr); // data_map_loader
+  }
 }
 
 inline py::list get_outputs_as_py_list(
@@ -495,13 +519,15 @@ struct PyModule final {
           program_verification)) {}
 
   explicit PyModule(
-      const std::string& path,
+      const std::string& program_path,
+      std::optional<const std::string>& data_path,
       bool enable_etdump,
      size_t debug_buffer_size = 0,
       Program::Verification program_verification =
           Program::Verification::InternalConsistency)
       : module_(load_module_from_file(
-            path,
+            program_path,
+            data_path,
             enable_etdump,
             debug_buffer_size,
             program_verification)) {}
@@ -521,14 +547,20 @@ struct PyModule final {
     return std::make_unique<PyModule>(
         buffer, enable_etdump, debug_buffer_size, program_verification);
   }
+
   static std::unique_ptr<PyModule> load_from_file(
-      const std::string& path,
+      const std::string& program_path,
+      std::optional<const std::string>& data_path,
       bool enable_etdump,
       size_t debug_buffer_size = 0,
       Program::Verification program_verification =
          Program::Verification::InternalConsistency) {
     return std::make_unique<PyModule>(
-        path, enable_etdump, debug_buffer_size, program_verification);
+        program_path,
+        data_path,
+        enable_etdump,
+        debug_buffer_size,
+        program_verification);
   }
 
   static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -1301,7 +1333,8 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
   m.def(
       "_load_for_executorch",
      PyModule::load_from_file,
-      py::arg("path"),
+      py::arg("program_path"),
+      py::arg("data_path") = std::nullopt,
      py::arg("enable_etdump") = false,
      py::arg("debug_buffer_size") = 0,
      py::arg("program_verification") =

extension/pybindings/pybindings.pyi (4 additions, 2 deletions)

@@ -147,7 +147,8 @@ class MethodMeta:
 
 @experimental("This API is experimental and subject to change without notice.")
 def _load_for_executorch(
-    path: str,
+    program_path: str,
+    data_path: Optional[str] = None,
     enable_etdump: bool = False,
     debug_buffer_size: int = 0,
     program_verification: Verification = Verification.InternalConsistency,
@@ -159,7 +160,8 @@ def _load_for_executorch(
     This API is experimental and subject to change without notice.
 
     Args:
-        path: File path to the ExecuTorch program as a string.
+        program_path: File path to the ExecuTorch program as a string.
+        data_path: File path to a .ptd file containing data used by the program.
         enable_etdump: If true, enables an ETDump which can store profiling information.
             See documentation at https://pytorch.org/executorch/main/etdump
             for how to use it.
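
Since data_path expects a .ptd file, here is a hedged sketch of producing one. It mirrors the steps of the test_program_data_separation test added below; exec_program.buffer, exec_program._tensor_data, and the '_default_external_constant' key are taken from that test, and the file names are placeholders:

# Sketch: export a module with its constants split out, yielding the
# .pte/.ptd pair that _load_for_executorch(program_path, data_path) expects.
import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from torch.export import export

model = torch.nn.Linear(3, 3)
exported = export(model, (torch.randn(3),), strict=True)
exec_program = to_edge(exported).to_executorch(
    config=ExecutorchBackendConfig(external_constants=True)
)
with open("linear.pte", "wb") as f:
    f.write(exec_program.buffer)
with open("linear.ptd", "wb") as f:
    # "_default_external_constant" is the blob name produced by
    # external_constants=True, per the test below.
    f.write(bytes(exec_program._tensor_data.pop("_default_external_constant")))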

extension/pybindings/test/make_test.py (15 additions, 0 deletions)

@@ -133,6 +133,21 @@ def get_inputs(self):
         return (torch.ones(2, 2), torch.ones(2, 2))
 
 
+class ModuleLinear(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+    def get_inputs(self):
+        return (torch.randn(3),)
+
+
 def create_program(
     eager_module: torch.nn.Module,
     et_config: Optional[ExecutorchBackendConfig] = None,

extension/pybindings/test/test_pybindings.py (33 additions, 0 deletions)

@@ -22,6 +22,7 @@
     ModuleAddWithAttributes,
     ModuleChannelsLast,
     ModuleChannelsLastInDefaultOut,
+    ModuleLinear,
     ModuleMulti,
 )
 from torch.export import export
@@ -600,3 +601,35 @@ def test_method_method_meta(self) -> None:
         self.assertEqual(output_tensor.is_memory_planned(), True)
         self.assertEqual(output_tensor.nbytes(), 16)
         self.assertEqual(str(output_tensor), tensor_info)
+
+    def test_program_data_separation(self) -> None:
+        eager_module = ModuleLinear()
+        inputs = eager_module.get_inputs()
+        exported_program = export(eager_module, inputs, strict=True)
+        exec_program = to_edge(exported_program).to_executorch(
+            config=ExecutorchBackendConfig(
+                # Move all tensor data to '_default_external_constant' file.
+                external_constants=True,
+            )
+        )
+
+        import os
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pte_file = os.path.join(tmpdir, "linear.pte")
+            with open(pte_file, "wb") as f:
+                f.write(exec_program.buffer)
+
+            ptd_file = os.path.join(tmpdir, "linear.ptd")
+            with open(ptd_file, "wb") as ptd:
+                tensor_data = bytes(
+                    exec_program._tensor_data.pop("_default_external_constant")
+                )
+                ptd.write(tensor_data)
+
+            executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
+
+            expected = eager_module(inputs[0])
+            executorch_output = executorch_program.forward(inputs)[0]
+            self.assertTrue(torch.allclose(expected, executorch_output))
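
Follow-up note: the test round-trips the weights through the new data_path argument and checks the ExecuTorch output against eager execution with torch.allclose. It reaches into exec_program._tensor_data, a private attribute, to pop the '_default_external_constant' blob that to_executorch produces when external_constants=True. Assuming the pybindings are built (e.g. via the EXECUTORCH_BUILD_PYBIND CMake option), the new case can likely be run on its own with: python -m pytest extension/pybindings/test/test_pybindings.py -k test_program_data_separation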
