Skip to content

Commit 96dfa9c

Browse files
authored
Add pybindings for bpte and ptd file
Differential Revision: D83518944 Pull Request resolved: #14678
1 parent 0cd8256 commit 96dfa9c

File tree

2 files changed

+103
-43
lines changed

2 files changed

+103
-43
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,24 @@ void setup_output_storage(
158158
}
159159
}
160160

161+
inline std::unique_ptr<DataLoader> loader_from_buffer(
162+
const void* ptr,
163+
size_t ptr_len) {
164+
return std::make_unique<BufferDataLoader>(ptr, ptr_len);
165+
}
166+
167+
inline std::unique_ptr<DataLoader> loader_from_file(const std::string& path) {
168+
Result<MmapDataLoader> res = MmapDataLoader::from(
169+
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
170+
THROW_IF_ERROR(
171+
res.error(),
172+
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
173+
path.c_str(),
174+
static_cast<uint32_t>(res.error()));
175+
176+
return std::make_unique<MmapDataLoader>(std::move(res.get()));
177+
}
178+
161179
inline std::unique_ptr<Module> load_module_from_buffer(
162180
const void* ptr,
163181
size_t ptr_len,
@@ -166,11 +184,11 @@ inline std::unique_ptr<Module> load_module_from_buffer(
166184
std::unique_ptr<runtime::EventTracer> event_tracer,
167185
Program::Verification program_verification) {
168186
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
169-
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
187+
auto loader = loader_from_buffer(ptr, ptr_len);
170188

171189
if (data_map_ptr.has_value() && data_map_len.has_value()) {
172-
auto data_map_loader = std::make_unique<BufferDataLoader>(
173-
data_map_ptr.value(), data_map_len.value());
190+
auto data_map_loader =
191+
loader_from_buffer(data_map_ptr.value(), data_map_len.value());
174192
return std::make_unique<Module>(
175193
std::move(loader),
176194
nullptr, // memory_allocator
@@ -194,27 +212,9 @@ inline std::unique_ptr<Module> load_module_from_file(
194212
Program::Verification program_verification) {
195213
EXECUTORCH_SCOPE_PROF("load_module_from_file");
196214

197-
Result<MmapDataLoader> program_loader_res = MmapDataLoader::from(
198-
program_path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
199-
THROW_IF_ERROR(
200-
program_loader_res.error(),
201-
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
202-
program_path.c_str(),
203-
static_cast<uint32_t>(program_loader_res.error()));
204-
auto program_loader =
205-
std::make_unique<MmapDataLoader>(std::move(program_loader_res.get()));
206-
215+
auto program_loader = loader_from_file(program_path);
207216
if (data_map_path.has_value()) {
208-
Result<MmapDataLoader> data_map_loader_res = MmapDataLoader::from(
209-
data_map_path->c_str(),
210-
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
211-
THROW_IF_ERROR(
212-
data_map_loader_res.error(),
213-
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
214-
data_map_path->c_str(),
215-
static_cast<uint32_t>(data_map_loader_res.error()));
216-
auto data_map_loader =
217-
std::make_unique<MmapDataLoader>(std::move(data_map_loader_res.get()));
217+
auto data_map_loader = loader_from_file(data_map_path.value());
218218
return std::make_unique<Module>(
219219
std::move(program_loader),
220220
nullptr, // memory_allocator
@@ -230,6 +230,22 @@ inline std::unique_ptr<Module> load_module_from_file(
230230
nullptr); // data_map_loader
231231
}
232232

233+
inline std::unique_ptr<Module> load_module_from_buffer_with_data_file(
234+
const void* ptr,
235+
size_t ptr_len,
236+
const std::string& data_map_path,
237+
std::unique_ptr<runtime::EventTracer> event_tracer,
238+
Program::Verification program_verification) {
239+
auto program_loader = loader_from_buffer(ptr, ptr_len);
240+
auto data_loader = loader_from_file(data_map_path);
241+
return std::make_unique<Module>(
242+
std::move(program_loader),
243+
nullptr, // memory_allocator
244+
nullptr, // temp_allocator
245+
std::move(event_tracer), // event_tracer
246+
std::move(data_loader));
247+
}
248+
233249
inline py::list get_outputs_as_py_list(
234250
const std::vector<EValue>& outputs,
235251
bool clone_outputs = true) {
@@ -555,6 +571,22 @@ struct PyModule final {
555571
setup_event_tracer(enable_etdump, debug_buffer_size),
556572
program_verification)) {}
557573

574+
explicit PyModule(
575+
const void* ptr,
576+
size_t ptr_len,
577+
const std::string& data_path,
578+
bool enable_etdump,
579+
size_t debug_buffer_size = 0,
580+
Program::Verification program_verification =
581+
Program::Verification::InternalConsistency)
582+
: debug_buffer_size_(debug_buffer_size),
583+
module_(load_module_from_buffer_with_data_file(
584+
ptr,
585+
ptr_len,
586+
data_path,
587+
setup_event_tracer(enable_etdump, debug_buffer_size),
588+
program_verification)) {}
589+
558590
explicit PyModule(
559591
const std::string& program_path,
560592
std::optional<const std::string>& data_path,
@@ -605,6 +637,7 @@ struct PyModule final {
605637
program_verification);
606638
}
607639

640+
// Load with data as a buffer.
608641
static std::unique_ptr<PyModule> load_from_bundled_program(
609642
PyBundledModule& m,
610643
std::optional<const py::bytes> data_map_buffer,
@@ -628,6 +661,21 @@ struct PyModule final {
628661
Program::Verification::InternalConsistency);
629662
}
630663

664+
// Load with data as a file.
665+
static std::unique_ptr<PyModule> load_from_bundled_program(
666+
PyBundledModule& m,
667+
const std::string& data_path,
668+
bool enable_etdump,
669+
size_t debug_buffer_size = 0) {
670+
return std::make_unique<PyModule>(
671+
m.get_program_ptr(),
672+
m.get_program_len(),
673+
data_path,
674+
enable_etdump,
675+
debug_buffer_size,
676+
Program::Verification::InternalConsistency);
677+
}
678+
631679
py::list run_method(
632680
const std::string& method_name,
633681
const py::sequence& inputs,
@@ -900,24 +948,6 @@ struct PyModule final {
900948
}
901949
};
902950

903-
inline std::unique_ptr<DataLoader> loader_from_buffer(
904-
const void* ptr,
905-
size_t ptr_len) {
906-
return std::make_unique<BufferDataLoader>(ptr, ptr_len);
907-
}
908-
909-
inline std::unique_ptr<DataLoader> loader_from_file(const std::string& path) {
910-
Result<MmapDataLoader> res = MmapDataLoader::from(
911-
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
912-
THROW_IF_ERROR(
913-
res.error(),
914-
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
915-
path.c_str(),
916-
static_cast<uint32_t>(res.error()));
917-
918-
return std::make_unique<MmapDataLoader>(std::move(res.get()));
919-
}
920-
921951
inline std::shared_ptr<ProgramState> load_program(
922952
std::unique_ptr<DataLoader> loader,
923953
Program::Verification program_verification) {
@@ -1474,12 +1504,25 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
14741504
call_guard);
14751505
m.def(
14761506
"_load_for_executorch_from_bundled_program",
1477-
&PyModule::load_from_bundled_program,
1507+
py::overload_cast<
1508+
PyBundledModule&,
1509+
std::optional<const py::bytes>,
1510+
bool,
1511+
size_t>(&PyModule::load_from_bundled_program),
14781512
py::arg("ptr"),
14791513
py::arg("data_map_buffer") = std::nullopt,
14801514
py::arg("enable_etdump") = false,
14811515
py::arg("debug_buffer_size") = 0,
14821516
call_guard);
1517+
m.def(
1518+
"_load_for_executorch_from_bundled_program",
1519+
py::overload_cast<PyBundledModule&, const std::string&, bool, size_t>(
1520+
&PyModule::load_from_bundled_program),
1521+
py::arg("ptr"),
1522+
py::arg("data_path"),
1523+
py::arg("enable_etdump") = false,
1524+
py::arg("debug_buffer_size") = 0,
1525+
call_guard);
14831526
m.def(
14841527
"_load_bundled_program_from_buffer",
14851528
&PyBundledModule::load_from_buffer,

extension/pybindings/test/test_pybindings.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ def test_program_data_separation(self) -> None:
701701
bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program)
702702
bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer)
703703

704-
# Load module from bundled program with external data
704+
# Load module from bundled program with external data buffer
705705
executorch_module_bundled = (
706706
self.runtime._load_for_executorch_from_bundled_program(
707707
bundled_module, data_buffer
@@ -710,6 +710,23 @@ def test_program_data_separation(self) -> None:
710710
executorch_output_bundled = executorch_module_bundled.forward(inputs)[0]
711711
self.assertTrue(torch.allclose(expected, executorch_output_bundled))
712712

713+
# Load module from bundled program with external data file
714+
with tempfile.TemporaryDirectory() as tmpdir:
715+
ptd_file = os.path.join(tmpdir, "linear.ptd")
716+
with open(ptd_file, "wb") as ptd:
717+
ptd.write(data_buffer)
718+
executorch_module_bundled_data_file = (
719+
self.runtime._load_for_executorch_from_bundled_program(
720+
bundled_module, ptd_file
721+
)
722+
)
723+
executorch_output_bundled_data_file = (
724+
executorch_module_bundled_data_file.forward(inputs)[0]
725+
)
726+
self.assertTrue(
727+
torch.allclose(expected, executorch_output_bundled_data_file)
728+
)
729+
713730
# Test 6: Bundled program without external data should fail
714731
executorch_module_bundled_no_data = (
715732
self.runtime._load_for_executorch_from_bundled_program(bundled_module)

0 commit comments

Comments
 (0)