diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 202b586cfa..73eb814336 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -1,6 +1,7 @@ function(download_onnxruntime) include(FetchContent) + set(LIB_PATH) if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) # For embedded systems set(possible_file_locations @@ -54,43 +55,23 @@ function(download_onnxruntime) elseif(WIN32) message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}") + set(possible_file_locations + $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg + ) + + set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") + set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") + set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a") + if(CMAKE_VS_PLATFORM_NAME STREQUAL Win32) - # If you don't have access to the Internet, - # please pre-download onnxruntime - # - # for 32-bit windows - set(possible_file_locations - $ENV{HOME}/Downloads/onnxruntime-win-x86-1.14.0.zip - ${PROJECT_SOURCE_DIR}/onnxruntime-win-x86-1.14.0.zip - ${PROJECT_BINARY_DIR}/onnxruntime-win-x86-1.14.0.zip - /tmp/onnxruntime-win-x86-1.14.0.zip - ) - - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-win-x86-1.14.0.zip") - set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/onnxruntime-win-x86-1.14.0.zip") - set(onnxruntime_HASH "SHA256=4214b130db602cbf31a6f26f25377ab077af0cf03c4ddd4651283e1fb68f56cf") + set(LIB_PATH "runtimes/win-x86/native/") else() - # If you don't have access to the Internet, - # please pre-download onnxruntime - # - # for 64-bit windows - set(possible_file_locations - $ENV{HOME}/Downloads/onnxruntime-win-x64-1.14.0.zip - ${PROJECT_SOURCE_DIR}/onnxruntime-win-x64-1.14.0.zip - ${PROJECT_BINARY_DIR}/onnxruntime-win-x64-1.14.0.zip - /tmp/onnxruntime-win-x64-1.14.0.zip - ) - - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-win-x64-1.14.0.zip") - set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/onnxruntime-win-x64-1.14.0.zip") - set(onnxruntime_HASH "SHA256=300eafef456748cde2743ee08845bd40ff1bab723697ff934eba6d4ce3519620") + set(LIB_PATH "runtimes/win-x64/native/") + # TODO(fangjun): Support win-arm and win-arm64 endif() - # After downloading, it contains: - # ./lib/onnxruntime.{dll,lib,pdb} - # ./lib/onnxruntime_providers_shared.{dll,lib,pdb} - # - # ./include - # It contains all the needed header files else() message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") @@ -119,28 +100,45 @@ function(download_onnxruntime) FetchContent_Populate(onnxruntime) endif() message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") + if(WIN32) + set(LIB_PATH "${onnxruntime_SOURCE_DIR}/${LIB_PATH}") + endif() - find_library(location_onnxruntime onnxruntime - PATHS - "${onnxruntime_SOURCE_DIR}/lib" - NO_CMAKE_SYSTEM_PATH - ) + message(STATUS "Addition lib search path for onnxruntime: ${LIB_PATH}") + + if(NOT WIN32) + find_library(location_onnxruntime onnxruntime + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + NO_CMAKE_SYSTEM_PATH + ) + else() + set(location_onnxruntime ${LIB_PATH}/onnxruntime.dll) + endif() message(STATUS "location_onnxruntime: ${location_onnxruntime}") add_library(onnxruntime SHARED IMPORTED) - set_target_properties(onnxruntime PROPERTIES - IMPORTED_LOCATION ${location_onnxruntime} - INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" - ) + if(NOT WIN32) + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime} + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" + ) + else() + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime} + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include" + ) + endif() + if(WIN32) set_property(TARGET onnxruntime PROPERTY - IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/lib/onnxruntime.lib" + IMPORTED_IMPLIB "${LIB_PATH}/onnxruntime.lib" ) - file(COPY ${onnxruntime_SOURCE_DIR}/lib/onnxruntime.dll + file(COPY ${LIB_PATH}/onnxruntime.dll DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE} ) @@ -151,7 +149,7 @@ function(download_onnxruntime) elseif(APPLE) file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.*.*dylib") elseif(WIN32) - file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/*.dll") + file(GLOB onnxruntime_lib_files "${LIB_PATH}/*.dll") endif() message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}") diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc index 9c50eb8cc1..06bd9a54d9 100644 --- a/sherpa-onnx/csrc/provider.cc +++ b/sherpa-onnx/csrc/provider.cc @@ -20,6 +20,8 @@ Provider StringToProvider(std::string s) { return Provider::kCUDA; } else if (s == "coreml") { return Provider::kCoreML; + } else if (s == "directml") { + return Provider::kDirectML; } else { SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); return Provider::kCPU; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index 8e0dcc0a16..3670cf68dd 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -13,9 +13,10 @@ namespace sherpa_onnx { // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java // for a list of available providers enum class Provider { - kCPU = 0, // CPUExecutionProvider - kCUDA = 1, // CUDAExecutionProvider - kCoreML = 2, // CoreMLExecutionProvider + kCPU = 0, // CPUExecutionProvider + kCUDA = 1, // CUDAExecutionProvider + kCoreML = 2, // CoreMLExecutionProvider + kDirectML = 3, // DmlExecutionProvider }; /** diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 9920ec1701..de9744b479 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/session.h" +#include + #include #include @@ -13,6 +15,10 @@ #include "coreml_provider_factory.h" // NOLINT #endif +#if defined(_WIN32) +#include "dml_provider_factory.h" // NOLINT +#endif + namespace sherpa_onnx { static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -23,12 +29,15 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, sess_opts.SetIntraOpNumThreads(num_threads); sess_opts.SetInterOpNumThreads(num_threads); + const auto &api = Ort::GetApi(); + switch (p) { case Provider::kCPU: break; // nothing to do for the CPU provider case Provider::kCUDA: { OrtCUDAProviderOptions options; options.device_id = 0; + // set more options on need sess_opts.AppendExecutionProvider_CUDA(options); break; @@ -36,10 +45,32 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, case Provider::kCoreML: { #if defined(__APPLE__) uint32_t coreml_flags = 0; - (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts, - coreml_flags); + OrtStatus *status = OrtSessionOptionsAppendExecutionProvider_CoreML( + sess_opts, coreml_flags); + if (status) { + const char *msg = api.GetErrorMessage(status); + SHERPA_ONNX_LOGE("Failed to enable CoreML: %s. Fallback to cpu", msg); + api.ReleaseStatus(status); + } #else SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!"); +#endif + break; + } + case Provider::kDirectML: { +#if defined(_WIN32) + sess_opts.DisableMemPattern(); + sess_opts.SetExecutionMode(ORT_SEQUENTIAL); + int32_t device_id = 0; + OrtStatus *status = + OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id); + if (status) { + const char *msg = api.GetErrorMessage(status); + SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg); + api.ReleaseStatus(status); + } +#else + SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!"); #endif break; }