Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 44 additions & 46 deletions cmake/onnxruntime.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function(download_onnxruntime)
include(FetchContent)

set(LIB_PATH)
if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
# For embedded systems
set(possible_file_locations
Expand Down Expand Up @@ -54,43 +55,23 @@ function(download_onnxruntime)
elseif(WIN32)
message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")

set(possible_file_locations
$ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
/tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
)

set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")

if(CMAKE_VS_PLATFORM_NAME STREQUAL Win32)
# If you don't have access to the Internet,
# please pre-download onnxruntime
#
# for 32-bit windows
set(possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-win-x86-1.14.0.zip
${PROJECT_SOURCE_DIR}/onnxruntime-win-x86-1.14.0.zip
${PROJECT_BINARY_DIR}/onnxruntime-win-x86-1.14.0.zip
/tmp/onnxruntime-win-x86-1.14.0.zip
)

set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-win-x86-1.14.0.zip")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/onnxruntime-win-x86-1.14.0.zip")
set(onnxruntime_HASH "SHA256=4214b130db602cbf31a6f26f25377ab077af0cf03c4ddd4651283e1fb68f56cf")
set(LIB_PATH "runtimes/win-x86/native/")
else()
# If you don't have access to the Internet,
# please pre-download onnxruntime
#
# for 64-bit windows
set(possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-win-x64-1.14.0.zip
${PROJECT_SOURCE_DIR}/onnxruntime-win-x64-1.14.0.zip
${PROJECT_BINARY_DIR}/onnxruntime-win-x64-1.14.0.zip
/tmp/onnxruntime-win-x64-1.14.0.zip
)

set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-win-x64-1.14.0.zip")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/onnxruntime-win-x64-1.14.0.zip")
set(onnxruntime_HASH "SHA256=300eafef456748cde2743ee08845bd40ff1bab723697ff934eba6d4ce3519620")
set(LIB_PATH "runtimes/win-x64/native/")
# TODO(fangjun): Support win-arm and win-arm64
endif()
# After downloading, it contains:
# ./lib/onnxruntime.{dll,lib,pdb}
# ./lib/onnxruntime_providers_shared.{dll,lib,pdb}
#
# ./include
# It contains all the needed header files
else()
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
Expand Down Expand Up @@ -119,28 +100,45 @@ function(download_onnxruntime)
FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
if(WIN32)
set(LIB_PATH "${onnxruntime_SOURCE_DIR}/${LIB_PATH}")
endif()

find_library(location_onnxruntime onnxruntime
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
message(STATUS "Addition lib search path for onnxruntime: ${LIB_PATH}")

if(NOT WIN32)
find_library(location_onnxruntime onnxruntime
PATHS
"${onnxruntime_SOURCE_DIR}/lib"
NO_CMAKE_SYSTEM_PATH
)
else()
set(location_onnxruntime ${LIB_PATH}/onnxruntime.dll)
endif()

message(STATUS "location_onnxruntime: ${location_onnxruntime}")

add_library(onnxruntime SHARED IMPORTED)

set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
)
if(NOT WIN32)
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
)
else()
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
)
endif()

if(WIN32)
set_property(TARGET onnxruntime
PROPERTY
IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/lib/onnxruntime.lib"
IMPORTED_IMPLIB "${LIB_PATH}/onnxruntime.lib"
)

file(COPY ${onnxruntime_SOURCE_DIR}/lib/onnxruntime.dll
file(COPY ${LIB_PATH}/onnxruntime.dll
DESTINATION
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
)
Expand All @@ -151,7 +149,7 @@ function(download_onnxruntime)
elseif(APPLE)
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.*.*dylib")
elseif(WIN32)
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/lib/*.dll")
file(GLOB onnxruntime_lib_files "${LIB_PATH}/*.dll")
endif()

message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Provider StringToProvider(std::string s) {
return Provider::kCUDA;
} else if (s == "coreml") {
return Provider::kCoreML;
} else if (s == "directml") {
return Provider::kDirectML;
} else {
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
return Provider::kCPU;
Expand Down
7 changes: 4 additions & 3 deletions sherpa-onnx/csrc/provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ namespace sherpa_onnx {
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kDirectML = 3, // DmlExecutionProvider
};

/**
Expand Down
35 changes: 33 additions & 2 deletions sherpa-onnx/csrc/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "sherpa-onnx/csrc/session.h"

#include <string.h>

#include <string>
#include <utility>

Expand All @@ -13,6 +15,10 @@
#include "coreml_provider_factory.h" // NOLINT
#endif

#if defined(_WIN32)
#include "dml_provider_factory.h" // NOLINT
#endif

namespace sherpa_onnx {

static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
Expand All @@ -23,23 +29,48 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
sess_opts.SetIntraOpNumThreads(num_threads);
sess_opts.SetInterOpNumThreads(num_threads);

const auto &api = Ort::GetApi();

switch (p) {
case Provider::kCPU:
break; // nothing to do for the CPU provider
case Provider::kCUDA: {
OrtCUDAProviderOptions options;
options.device_id = 0;

// set more options on need
sess_opts.AppendExecutionProvider_CUDA(options);
break;
}
case Provider::kCoreML: {
#if defined(__APPLE__)
uint32_t coreml_flags = 0;
(void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts,
coreml_flags);
OrtStatus *status = OrtSessionOptionsAppendExecutionProvider_CoreML(
sess_opts, coreml_flags);
if (status) {
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE("Failed to enable CoreML: %s. Fallback to cpu", msg);
api.ReleaseStatus(status);
}
#else
SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!");
#endif
break;
}
case Provider::kDirectML: {
#if defined(_WIN32)
sess_opts.DisableMemPattern();
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
int32_t device_id = 0;
OrtStatus *status =
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
if (status) {
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
api.ReleaseStatus(status);
}
#else
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
break;
}
Expand Down