@@ -965,6 +965,89 @@ struct PyModule final {
965
965
}
966
966
};
967
967
968
+ inline std::unique_ptr<DataLoader> loader_from_buffer (
969
+ const void * ptr,
970
+ size_t ptr_len) {
971
+ return std::make_unique<BufferDataLoader>(ptr, ptr_len);
972
+ }
973
+
974
+ inline std::unique_ptr<DataLoader> loader_from_file (const std::string& path) {
975
+ Result<MmapDataLoader> res = MmapDataLoader::from (
976
+ path.c_str (), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
977
+ THROW_IF_ERROR (
978
+ res.error (),
979
+ " Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
980
+ path.c_str (),
981
+ static_cast <uint32_t >(res.error ()));
982
+
983
+ return std::make_unique<MmapDataLoader>(std::move (res.get ()));
984
+ }
985
+
986
+ inline std::unique_ptr<Program> load_program (
987
+ DataLoader* loader,
988
+ Program::Verification program_verification) {
989
+ Result<Program> res = Program::load (loader, program_verification);
990
+ THROW_IF_ERROR (
991
+ res.error (),
992
+ " Failed to load program, error: 0x:%" PRIx32,
993
+ static_cast <uint32_t >(res.error ()));
994
+ return std::make_unique<Program>(std::move (res.get ()));
995
+ }
996
+
997
+ struct PyProgram final {
998
+ explicit PyProgram (
999
+ const py::bytes& buffer,
1000
+ Program::Verification program_verification =
1001
+ Program::Verification::Minimal)
1002
+ : loader_(loader_from_buffer(
1003
+ buffer.cast<std::string_view>().data(),
1004
+ py::len(buffer))),
1005
+ program_(load_program(loader_.get(), program_verification)) {}
1006
+
1007
+ explicit PyProgram (
1008
+ const std::string& path,
1009
+ Program::Verification program_verification =
1010
+ Program::Verification::Minimal)
1011
+ : loader_(loader_from_file(path)),
1012
+ program_(load_program(loader_.get(), program_verification)) {}
1013
+
1014
+ static std::unique_ptr<PyProgram> load_from_buffer (
1015
+ const py::bytes& buffer,
1016
+ Program::Verification program_verification =
1017
+ Program::Verification::Minimal) {
1018
+ return std::make_unique<PyProgram>(buffer, program_verification);
1019
+ }
1020
+
1021
+ static std::unique_ptr<PyProgram> load_from_file (
1022
+ const std::string& path,
1023
+ Program::Verification program_verification =
1024
+ Program::Verification::Minimal) {
1025
+ return std::make_unique<PyProgram>(path, program_verification);
1026
+ }
1027
+
1028
+ PyProgram (const PyProgram&) = delete ;
1029
+ PyProgram& operator =(const PyProgram&) = delete ;
1030
+ PyProgram (PyProgram&&) = default ;
1031
+ PyProgram& operator =(PyProgram&&) = default ;
1032
+
1033
+ size_t num_methods () const {
1034
+ return program_->num_methods ();
1035
+ }
1036
+
1037
+ std::string get_method_name (size_t method_index) const {
1038
+ Result<const char *> res = program_->get_method_name (method_index);
1039
+ THROW_IF_ERROR (
1040
+ res.error (),
1041
+ " Failed get method name, error: 0x:%" PRIx32,
1042
+ static_cast <uint32_t >(res.error ()));
1043
+ return std::string (res.get ());
1044
+ }
1045
+
1046
+ private:
1047
+ std::unique_ptr<DataLoader> loader_;
1048
+ std::unique_ptr<Program> program_;
1049
+ };
1050
+
968
1051
void create_profile_block (const std::string& name) {
969
1052
EXECUTORCH_PROFILE_CREATE_BLOCK (name.c_str ());
970
1053
}
@@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
1151
1234
py::arg (" index" ),
1152
1235
call_guard)
1153
1236
.def (" __repr__" , &PyMethodMeta::repr, call_guard);
1237
+
1238
+ m.def (
1239
+ " _load_program" ,
1240
+ &PyProgram::load_from_file,
1241
+ py::arg (" path" ),
1242
+ py::arg (" program_verification" ) = Program::Verification::Minimal,
1243
+ call_guard);
1244
+ m.def (
1245
+ " _load_program_from_buffer" ,
1246
+ &PyProgram::load_from_buffer,
1247
+ py::arg (" buffer" ),
1248
+ py::arg (" program_verification" ) = Program::Verification::Minimal,
1249
+ call_guard);
1250
+ py::class_<PyProgram>(m, " ExecuTorchProgram" )
1251
+ .def (" num_methods" , &PyProgram::num_methods, call_guard)
1252
+ .def (
1253
+ " get_method_name" ,
1254
+ &PyProgram::get_method_name,
1255
+ py::arg (" method_index" ),
1256
+ call_guard);
1154
1257
}
1155
1258
1156
1259
namespace {
0 commit comments