@@ -965,6 +965,89 @@ struct PyModule final {
965965 }
966966};
967967
968+ inline std::unique_ptr<DataLoader> loader_from_buffer (
969+ const void * ptr,
970+ size_t ptr_len) {
971+ return std::make_unique<BufferDataLoader>(ptr, ptr_len);
972+ }
973+
974+ inline std::unique_ptr<DataLoader> loader_from_file (const std::string& path) {
975+ Result<MmapDataLoader> res = MmapDataLoader::from (
976+ path.c_str (), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
977+ THROW_IF_ERROR (
978+ res.error (),
979+ " Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
980+ path.c_str (),
981+ static_cast <uint32_t >(res.error ()));
982+
983+ return std::make_unique<MmapDataLoader>(std::move (res.get ()));
984+ }
985+
986+ inline std::unique_ptr<Program> load_program (
987+ DataLoader* loader,
988+ Program::Verification program_verification) {
989+ Result<Program> res = Program::load (loader, program_verification);
990+ THROW_IF_ERROR (
991+ res.error (),
992+ " Failed to load program, error: 0x:%" PRIx32,
993+ static_cast <uint32_t >(res.error ()));
994+ return std::make_unique<Program>(std::move (res.get ()));
995+ }
996+
997+ struct PyProgram final {
998+ explicit PyProgram (
999+ const py::bytes& buffer,
1000+ Program::Verification program_verification =
1001+ Program::Verification::Minimal)
1002+ : loader_(loader_from_buffer(
1003+ buffer.cast<std::string_view>().data(),
1004+ py::len(buffer))),
1005+ program_(load_program(loader_.get(), program_verification)) {}
1006+
1007+ explicit PyProgram (
1008+ const std::string& path,
1009+ Program::Verification program_verification =
1010+ Program::Verification::Minimal)
1011+ : loader_(loader_from_file(path)),
1012+ program_(load_program(loader_.get(), program_verification)) {}
1013+
1014+ static std::unique_ptr<PyProgram> load_from_buffer (
1015+ const py::bytes& buffer,
1016+ Program::Verification program_verification =
1017+ Program::Verification::Minimal) {
1018+ return std::make_unique<PyProgram>(buffer, program_verification);
1019+ }
1020+
1021+ static std::unique_ptr<PyProgram> load_from_file (
1022+ const std::string& path,
1023+ Program::Verification program_verification =
1024+ Program::Verification::Minimal) {
1025+ return std::make_unique<PyProgram>(path, program_verification);
1026+ }
1027+
1028+ PyProgram (const PyProgram&) = delete ;
1029+ PyProgram& operator =(const PyProgram&) = delete ;
1030+ PyProgram (PyProgram&&) = default ;
1031+ PyProgram& operator =(PyProgram&&) = default ;
1032+
1033+ size_t num_methods () const {
1034+ return program_->num_methods ();
1035+ }
1036+
1037+ std::string get_method_name (size_t method_index) const {
1038+ Result<const char *> res = program_->get_method_name (method_index);
1039+ THROW_IF_ERROR (
1040+ res.error (),
1041+ " Failed get method name, error: 0x:%" PRIx32,
1042+ static_cast <uint32_t >(res.error ()));
1043+ return std::string (res.get ());
1044+ }
1045+
1046+ private:
1047+ std::unique_ptr<DataLoader> loader_;
1048+ std::unique_ptr<Program> program_;
1049+ };
1050+
9681051void create_profile_block (const std::string& name) {
9691052 EXECUTORCH_PROFILE_CREATE_BLOCK (name.c_str ());
9701053}
@@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11511234 py::arg (" index" ),
11521235 call_guard)
11531236 .def (" __repr__" , &PyMethodMeta::repr, call_guard);
1237+
1238+ m.def (
1239+ " _load_program" ,
1240+ &PyProgram::load_from_file,
1241+ py::arg (" path" ),
1242+ py::arg (" program_verification" ) = Program::Verification::Minimal,
1243+ call_guard);
1244+ m.def (
1245+ " _load_program_from_buffer" ,
1246+ &PyProgram::load_from_buffer,
1247+ py::arg (" buffer" ),
1248+ py::arg (" program_verification" ) = Program::Verification::Minimal,
1249+ call_guard);
1250+ py::class_<PyProgram>(m, " ExecuTorchProgram" )
1251+ .def (" num_methods" , &PyProgram::num_methods, call_guard)
1252+ .def (
1253+ " get_method_name" ,
1254+ &PyProgram::get_method_name,
1255+ py::arg (" method_index" ),
1256+ call_guard);
11541257}
11551258
11561259namespace {
0 commit comments