Skip to content

Commit 1e5315c

Browse files
authored
feat: add directml support (k2-fsa#1153)
1 parent 3cd0598 commit 1e5315c

File tree

6 files changed

+218
-7
lines changed

6 files changed

+218
-7
lines changed

CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
3030
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
3131
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
3232
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
33+
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
3334
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
3435
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
3536
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
9495
endif()
9596
endif()
9697

98+
if(SHERPA_ONNX_ENABLE_DIRECTML)
99+
message(WARNING "\
100+
Compiling with DirectML enabled. Please make sure Windows 10 SDK
101+
is installed on your system. Otherwise, you will get errors at runtime.
102+
Please refer to
103+
https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
104+
to install Windows 10 SDK if you have not installed it.")
105+
if(NOT BUILD_SHARED_LIBS)
106+
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
107+
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
108+
endif()
109+
endif()
110+
97111
# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
98112
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
99113
if(MSVC)
@@ -160,6 +174,14 @@ else()
160174
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
161175
endif()
162176

177+
if(SHERPA_ONNX_ENABLE_DIRECTML)
178+
message(STATUS "DirectML is enabled")
179+
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
180+
else()
181+
message(WARNING "DirectML is disabled")
182+
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
183+
endif()
184+
163185
if(SHERPA_ONNX_ENABLE_WASM_TTS)
164186
if(NOT SHERPA_ONNX_ENABLE_TTS)
165187
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright (c) 2022-2023 Xiaomi Corporation
2+
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
3+
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
4+
message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")
5+
6+
if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
7+
message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
8+
endif()
9+
10+
if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
11+
message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
12+
endif()
13+
14+
if(NOT BUILD_SHARED_LIBS)
15+
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
16+
endif()
17+
18+
if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
19+
message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
20+
endif()
21+
22+
set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
23+
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
24+
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")
25+
26+
# If you don't have access to the Internet,
27+
# please download onnxruntime to one of the following locations.
28+
# You can add more if you want.
29+
set(possible_file_locations
30+
$ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
31+
${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
32+
${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
33+
/tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
34+
)
35+
36+
foreach(f IN LISTS possible_file_locations)
37+
if(EXISTS ${f})
38+
set(onnxruntime_URL "${f}")
39+
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
40+
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
41+
set(onnxruntime_URL2)
42+
break()
43+
endif()
44+
endforeach()
45+
46+
FetchContent_Declare(onnxruntime
47+
URL
48+
${onnxruntime_URL}
49+
${onnxruntime_URL2}
50+
URL_HASH ${onnxruntime_HASH}
51+
)
52+
53+
FetchContent_GetProperties(onnxruntime)
54+
if(NOT onnxruntime_POPULATED)
55+
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
56+
FetchContent_Populate(onnxruntime)
57+
endif()
58+
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
59+
60+
find_library(location_onnxruntime onnxruntime
61+
PATHS
62+
"${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native"
63+
NO_CMAKE_SYSTEM_PATH
64+
)
65+
66+
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
67+
68+
add_library(onnxruntime SHARED IMPORTED)
69+
70+
set_target_properties(onnxruntime PROPERTIES
71+
IMPORTED_LOCATION ${location_onnxruntime}
72+
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
73+
)
74+
75+
set_property(TARGET onnxruntime
76+
PROPERTY
77+
IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib"
78+
)
79+
80+
file(COPY ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
81+
DESTINATION
82+
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
83+
)
84+
85+
file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*")
86+
87+
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
88+
89+
if(SHERPA_ONNX_ENABLE_PYTHON)
90+
install(FILES ${onnxruntime_lib_files} DESTINATION ..)
91+
else()
92+
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
93+
endif()
94+
95+
install(FILES ${onnxruntime_lib_files} DESTINATION bin)
96+
97+
# Setup DirectML
98+
99+
set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
100+
set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")
101+
102+
set(possible_directml_file_locations
103+
$ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg
104+
${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
105+
${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
106+
/tmp/Microsoft.AI.DirectML.1.15.0.nupkg
107+
)
108+
109+
foreach(f IN LISTS possible_directml_file_locations)
110+
if(EXISTS ${f})
111+
set(directml_URL "${f}")
112+
file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
113+
message(STATUS "Found local downloaded DirectML: ${directml_URL}")
114+
break()
115+
endif()
116+
endforeach()
117+
118+
FetchContent_Declare(directml
119+
URL
120+
${directml_URL}
121+
URL_HASH ${directml_HASH}
122+
)
123+
124+
FetchContent_GetProperties(directml)
125+
if(NOT directml_POPULATED)
126+
message(STATUS "Downloading DirectML from ${directml_URL}")
127+
FetchContent_Populate(directml)
128+
endif()
129+
message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")
130+
131+
find_library(location_directml DirectML
132+
PATHS
133+
"${directml_SOURCE_DIR}/bin/x64-win"
134+
NO_CMAKE_SYSTEM_PATH
135+
)
136+
137+
message(STATUS "location_directml: ${location_directml}")
138+
139+
add_library(directml SHARED IMPORTED)
140+
141+
set_target_properties(directml PROPERTIES
142+
IMPORTED_LOCATION ${location_directml}
143+
INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win"
144+
)
145+
146+
set_property(TARGET directml
147+
PROPERTY
148+
IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib"
149+
)
150+
151+
file(COPY ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
152+
DESTINATION
153+
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
154+
)
155+
156+
file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*")
157+
158+
message(STATUS "DirectML lib files: ${directml_lib_files}")
159+
160+
install(FILES ${directml_lib_files} DESTINATION lib)
161+
install(FILES ${directml_lib_files} DESTINATION bin)

cmake/onnxruntime.cmake

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ function(download_onnxruntime)
9595
include(onnxruntime-win-arm64)
9696
else()
9797
# for 64-bit windows (x64)
98-
if(BUILD_SHARED_LIBS)
98+
if(SHERPA_ONNX_ENABLE_DIRECTML)
99+
message(STATUS "Use DirectML")
100+
include(onnxruntime-win-x64-directml)
101+
elseif(BUILD_SHARED_LIBS)
99102
message(STATUS "Use dynamic onnxruntime libraries")
100103
if(SHERPA_ONNX_ENABLE_GPU)
101104
include(onnxruntime-win-x64-gpu)

sherpa-onnx/csrc/provider.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
2626
return Provider::kNNAPI;
2727
} else if (s == "trt") {
2828
return Provider::kTRT;
29+
} else if (s == "directml") {
30+
return Provider::kDirectML;
2931
} else {
3032
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
3133
return Provider::kCPU;

sherpa-onnx/csrc/provider.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
1414
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
1515
// for a list of available providers
1616
enum class Provider {
17-
kCPU = 0, // CPUExecutionProvider
18-
kCUDA = 1, // CUDAExecutionProvider
19-
kCoreML = 2, // CoreMLExecutionProvider
20-
kXnnpack = 3, // XnnpackExecutionProvider
21-
kNNAPI = 4, // NnapiExecutionProvider
22-
kTRT = 5, // TensorRTExecutionProvider
17+
kCPU = 0, // CPUExecutionProvider
18+
kCUDA = 1, // CUDAExecutionProvider
19+
kCoreML = 2, // CoreMLExecutionProvider
20+
kXnnpack = 3, // XnnpackExecutionProvider
21+
kNNAPI = 4, // NnapiExecutionProvider
22+
kTRT = 5, // TensorRTExecutionProvider
23+
kDirectML = 6, // DmlExecutionProvider
2324
};
2425

2526
/**

sherpa-onnx/csrc/session.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
#include "nnapi_provider_factory.h" // NOLINT
2020
#endif
2121

22+
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
23+
#include "dml_provider_factory.h" // NOLINT
24+
#endif
25+
2226
namespace sherpa_onnx {
2327

2428
static void OrtStatusFailure(OrtStatus *status, const char *s) {
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
167171
}
168172
break;
169173
}
174+
case Provider::kDirectML: {
175+
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
176+
sess_opts.DisableMemPattern();
177+
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
178+
int32_t device_id = 0;
179+
OrtStatus *status =
180+
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
181+
if (status) {
182+
const auto &api = Ort::GetApi();
183+
const char *msg = api.GetErrorMessage(status);
184+
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
185+
api.ReleaseStatus(status);
186+
}
187+
#else
188+
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
189+
#endif
190+
break;
191+
}
170192
case Provider::kCoreML: {
171193
#if defined(__APPLE__)
172194
uint32_t coreml_flags = 0;

0 commit comments

Comments
 (0)