Skip to content

Commit 7f38b5f

Browse files
RegisterExecutionProviderLibrary (microsoft#1628)
1 parent 0f0334b commit 7f38b5f

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

src/models/onnxruntime_api.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,14 @@ std::span<TAlloc> Allocate(OrtAllocator& allocator,
416416
return std::span(unique_ptr.get(), size);
417417
}
418418

419+
inline void RegisterExecutionProviderLibrary(OrtEnv* env, const char* registration_name, const ORTCHAR_T* path) {
420+
Ort::api->RegisterExecutionProviderLibrary(env, registration_name, path);
421+
}
422+
423+
inline void UnregisterExecutionProviderLibrary(OrtEnv* env, const char* registration_name) {
424+
Ort::api->UnregisterExecutionProviderLibrary(env, registration_name);
425+
}
426+
419427
} // namespace Ort
420428

421429
/** \brief The Status that holds ownership of OrtStatus received from C API

src/ort_genai_c.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,4 +983,12 @@ void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* p) { delete p; }
983983
void OGA_API_CALL OgaDestroyEngine(OgaEngine* p) { p->ExternalRelease(); }
984984
void OGA_API_CALL OgaDestroyRequest(OgaRequest* p) { p->ExternalRelease(); }
985985

986+
void OGA_API_CALL OgaRegisterExecutionProviderLibrary(const char* registration_name, const char* library_path) {
987+
Ort::RegisterExecutionProviderLibrary(&(Generators::GetOrtEnv()), registration_name, fs::path(library_path).c_str());
988+
}
989+
990+
void OGA_API_CALL OgaUnregisterExecutionProviderLibrary(const char* registration_name) {
991+
Ort::UnregisterExecutionProviderLibrary(&(Generators::GetOrtEnv()), registration_name);
992+
}
993+
986994
} // extern "C"

src/ort_genai_c.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,21 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaRequestGetUnseenToken(OgaRequest* request,
968968
*/
969969
OGA_EXPORT OgaResult* OGA_API_CALL OgaRequestIsDone(const OgaRequest* request, bool* out);
970970

971+
/**
972+
* \brief Registers an execution provider library with ONNXRuntime API.
973+
* \param registration_name name for registration.
974+
* \param path provider path.
975+
*
976+
*/
977+
OGA_EXPORT void OGA_API_CALL OgaRegisterExecutionProviderLibrary(const char* registration_name, const char* library_path);
978+
979+
/**
980+
* \brief Unregisters an execution provider library with ONNXRuntime API.
981+
* \param registration_name name for registration.
982+
*
983+
*/
984+
OGA_EXPORT void OGA_API_CALL OgaUnregisterExecutionProviderLibrary(const char* registration_name);
985+
971986
#ifdef __cplusplus
972987
}
973988
#endif

src/python/python.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,12 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
584584

585585
m.def("set_current_gpu_device_id", [](int device_id) { Ort::SetCurrentGpuDeviceId(device_id); });
586586
m.def("get_current_gpu_device_id", []() { return Ort::GetCurrentGpuDeviceId(); });
587-
}
587+
588+
m.def("register_execution_provider_library", [](const std::string& provider_name, const std::string& path_str) {
589+
OgaRegisterExecutionProviderLibrary(provider_name.c_str(), path_str.c_str());
590+
});
591+
592+
m.def("unregister_execution_provider_library", [](const std::string& provider_name) {
593+
OgaUnregisterExecutionProviderLibrary(provider_name.c_str());
594+
});
595+
}

0 commit comments

Comments
 (0)