From 56c838cf478f999848b38051783344ac50583d59 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Mon, 29 Sep 2025 17:13:57 -0700 Subject: [PATCH] Add pybindings for bpte and ptd file Summary: 1. Overload `_load_for_executorch_from_bundled_program` to take in ptd file as buffer or file 2. Refactor data loader cases to use the utility functions `loader_from_buffer`, `loader_from_file` Functionality used by sas compiler team Differential Revision: D83518944 --- extension/pybindings/pybindings.cpp | 127 +++++++++++++------ extension/pybindings/test/test_pybindings.py | 19 ++- 2 files changed, 103 insertions(+), 43 deletions(-) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index a896a4bde36..c3cd4ed0b47 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -158,6 +158,24 @@ void setup_output_storage( } } +inline std::unique_ptr loader_from_buffer( + const void* ptr, + size_t ptr_len) { + return std::make_unique(ptr, ptr_len); +} + +inline std::unique_ptr loader_from_file(const std::string& path) { + Result res = MmapDataLoader::from( + path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); + THROW_IF_ERROR( + res.error(), + "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, + path.c_str(), + static_cast(res.error())); + + return std::make_unique(std::move(res.get())); +} + inline std::unique_ptr load_module_from_buffer( const void* ptr, size_t ptr_len, @@ -166,11 +184,11 @@ inline std::unique_ptr load_module_from_buffer( std::unique_ptr event_tracer, Program::Verification program_verification) { EXECUTORCH_SCOPE_PROF("load_module_from_buffer"); - auto loader = std::make_unique(ptr, ptr_len); + auto loader = loader_from_buffer(ptr, ptr_len); if (data_map_ptr.has_value() && data_map_len.has_value()) { - auto data_map_loader = std::make_unique( - data_map_ptr.value(), data_map_len.value()); + auto data_map_loader = + loader_from_buffer(data_map_ptr.value(), data_map_len.value()); return std::make_unique( std::move(loader), nullptr, // memory_allocator @@ -194,27 +212,9 @@ inline std::unique_ptr load_module_from_file( Program::Verification program_verification) { EXECUTORCH_SCOPE_PROF("load_module_from_file"); - Result program_loader_res = MmapDataLoader::from( - program_path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); - THROW_IF_ERROR( - program_loader_res.error(), - "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, - program_path.c_str(), - static_cast(program_loader_res.error())); - auto program_loader = - std::make_unique(std::move(program_loader_res.get())); - + auto program_loader = loader_from_file(program_path); if (data_map_path.has_value()) { - Result data_map_loader_res = MmapDataLoader::from( - data_map_path->c_str(), - MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); - THROW_IF_ERROR( - data_map_loader_res.error(), - "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, - data_map_path->c_str(), - static_cast(data_map_loader_res.error())); - auto data_map_loader = - std::make_unique(std::move(data_map_loader_res.get())); + auto data_map_loader = loader_from_file(data_map_path.value()); return std::make_unique( std::move(program_loader), nullptr, // memory_allocator @@ -230,6 +230,22 @@ inline std::unique_ptr load_module_from_file( nullptr); // data_map_loader } +inline std::unique_ptr load_module_from_buffer_with_data_file( + const void* ptr, + size_t ptr_len, + const std::string& data_map_path, + std::unique_ptr event_tracer, + Program::Verification program_verification) { + auto program_loader = loader_from_buffer(ptr, ptr_len); + auto data_loader = loader_from_file(data_map_path); + return std::make_unique( + std::move(program_loader), + nullptr, // memory_allocator + nullptr, // temp_allocator + std::move(event_tracer), // event_tracer + std::move(data_loader)); +} + inline py::list get_outputs_as_py_list( const std::vector& outputs, bool clone_outputs = true) { @@ -555,6 +571,22 @@ struct PyModule final { setup_event_tracer(enable_etdump, debug_buffer_size), program_verification)) {} + explicit PyModule( + const void* ptr, + size_t ptr_len, + const std::string& data_path, + bool enable_etdump, + size_t debug_buffer_size = 0, + Program::Verification program_verification = + Program::Verification::InternalConsistency) + : debug_buffer_size_(debug_buffer_size), + module_(load_module_from_buffer_with_data_file( + ptr, + ptr_len, + data_path, + setup_event_tracer(enable_etdump, debug_buffer_size), + program_verification)) {} + explicit PyModule( const std::string& program_path, std::optional& data_path, @@ -605,6 +637,7 @@ struct PyModule final { program_verification); } + // Load with data as a buffer. static std::unique_ptr load_from_bundled_program( PyBundledModule& m, std::optional data_map_buffer, @@ -628,6 +661,21 @@ struct PyModule final { Program::Verification::InternalConsistency); } + // Load with data as a file. + static std::unique_ptr load_from_bundled_program( + PyBundledModule& m, + const std::string& data_path, + bool enable_etdump, + size_t debug_buffer_size = 0) { + return std::make_unique( + m.get_program_ptr(), + m.get_program_len(), + data_path, + enable_etdump, + debug_buffer_size, + Program::Verification::InternalConsistency); + } + py::list run_method( const std::string& method_name, const py::sequence& inputs, @@ -900,24 +948,6 @@ struct PyModule final { } }; -inline std::unique_ptr loader_from_buffer( - const void* ptr, - size_t ptr_len) { - return std::make_unique(ptr, ptr_len); -} - -inline std::unique_ptr loader_from_file(const std::string& path) { - Result res = MmapDataLoader::from( - path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); - THROW_IF_ERROR( - res.error(), - "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, - path.c_str(), - static_cast(res.error())); - - return std::make_unique(std::move(res.get())); -} - inline std::shared_ptr load_program( std::unique_ptr loader, Program::Verification program_verification) { @@ -1474,12 +1504,25 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { call_guard); m.def( "_load_for_executorch_from_bundled_program", - &PyModule::load_from_bundled_program, + py::overload_cast< + PyBundledModule&, + std::optional, + bool, + size_t>(&PyModule::load_from_bundled_program), py::arg("ptr"), py::arg("data_map_buffer") = std::nullopt, py::arg("enable_etdump") = false, py::arg("debug_buffer_size") = 0, call_guard); + m.def( + "_load_for_executorch_from_bundled_program", + py::overload_cast( + &PyModule::load_from_bundled_program), + py::arg("ptr"), + py::arg("data_path"), + py::arg("enable_etdump") = false, + py::arg("debug_buffer_size") = 0, + call_guard); m.def( "_load_bundled_program_from_buffer", &PyBundledModule::load_from_buffer, diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py index 02ad6b5e327..ec45428c7d7 100644 --- a/extension/pybindings/test/test_pybindings.py +++ b/extension/pybindings/test/test_pybindings.py @@ -701,7 +701,7 @@ def test_program_data_separation(self) -> None: bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program) bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer) - # Load module from bundled program with external data + # Load module from bundled program with external data buffer executorch_module_bundled = ( self.runtime._load_for_executorch_from_bundled_program( bundled_module, data_buffer @@ -710,6 +710,23 @@ def test_program_data_separation(self) -> None: executorch_output_bundled = executorch_module_bundled.forward(inputs)[0] self.assertTrue(torch.allclose(expected, executorch_output_bundled)) + # Load module from bundled program with external data file + with tempfile.TemporaryDirectory() as tmpdir: + ptd_file = os.path.join(tmpdir, "linear.ptd") + with open(ptd_file, "wb") as ptd: + ptd.write(data_buffer) + executorch_module_bundled_data_file = ( + self.runtime._load_for_executorch_from_bundled_program( + bundled_module, ptd_file + ) + ) + executorch_output_bundled_data_file = ( + executorch_module_bundled_data_file.forward(inputs)[0] + ) + self.assertTrue( + torch.allclose(expected, executorch_output_bundled_data_file) + ) + # Test 6: Bundled program without external data should fail executorch_module_bundled_no_data = ( self.runtime._load_for_executorch_from_bundled_program(bundled_module)