Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit fff435f

Browse files
authored
Support hardware acceleration on non-cuda platforms (#21)
* Support hardware acceleration on non-cuda platforms * Ensure hardware accel is used on DML * Fix accidental delete * Fix AMD
1 parent f2a0519 commit fff435f

File tree

6 files changed

+29
-13
lines changed

6 files changed

+29
-13
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ros_msft_onnx/testdata/model.onnx

ros_msft_onnx/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ if(MSVC)
1717

1818
endif()
1919

20+
add_definitions(-DBOOST_BIND_GLOBAL_PLACEHOLDERS)
21+
22+
2023
option(CUDA_SUPPORT "use CUDA support onnxruntime library" OFF)
2124

2225
find_package(Eigen3 REQUIRED)
@@ -65,9 +68,9 @@ if(CUDA_SUPPORT)
6568
set(PACKAGE_URL "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.Gpu/1.7.1")
6669
set(PACKAGE_SHA512 "41112118007aae34fcc38100152df6e6fa5fc567e61aa4ded42a26d39751f1be7ec225c0d73799f065015e284f0fb9bd7e0835c733e9abad5b0243a391411f8d")
6770
else()
68-
set(ONNX_RUNTIME "Microsoft.ML.OnnxRuntime.1.7.0")
69-
set(PACKAGE_URL "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime/1.7.0")
70-
set(PACKAGE_SHA512 "1fc15386bdfa455f457e50899e3c9c454aafbdc345799dcf4ecfd6990a9dbd8cd7f0b1f3bf412c47c900543c535f95aa1cb1e14e9851cb9b600c60a981f38a50")
71+
set(ONNX_RUNTIME "Microsoft.ML.OnnxRuntime.DirectML.1.7.0")
72+
set(PACKAGE_URL "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/1.7.0")
73+
set(PACKAGE_SHA512 "2e5bd2c0ade72444d4efdfbd6a75571aaa72045769f9b5847186129c9e5e667ad080d5d2b9a12cce88c9eee68302be89cdb7030ccefa3d572e591b1c453c7340")
7174
endif()
7275

7376
file(DOWNLOAD

ros_msft_onnx/launch/tracker.launch

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
<node pkg="ros_msft_onnx" type="ros_msft_onnx_node" name="ros_msft_onnx" output="screen">
55
<param name="image_topic" value="$(eval '/camera/image_raw' if os_windows_arg else '/cv_camera/image_raw')"/>
6+
<param name="tensor_width" value="416"/>
7+
<param name="tensor_height" value="416"/>
68
</node>
79

810
<!-- The camera node will be selected based on os. ros_msft_camera for Windows and cv_camera for others. -->

ros_msft_onnx/src/main.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66

77
#include "ros_msft_onnx/ros_msft_onnx.h"
88

9+
#ifdef _WIN32
10+
#include <objbase.h>
11+
#endif
12+
913
using namespace std;
1014

1115
int main(int argc, char **argv)
1216
{
13-
/*
14-
ROS_WARN("ONNX: Waiting for Debugger");
15-
while (!IsDebuggerPresent())
16-
{
17-
Sleep(5);
18-
}
19-
*/
17+
#ifdef _WIN32
18+
HRESULT hr = CoInitializeEx(NULL, COINIT_MULTITHREADED);
19+
#endif
2020

2121
ros::init(argc, argv, "ros_msft_onnx");
2222

@@ -39,5 +39,7 @@ int main(int argc, char **argv)
3939
return 1;
4040
}
4141

42-
42+
#ifdef _WIN32
43+
CoUninitialize();
44+
#endif
4345
}

ros_msft_onnx/src/ros_msft_onnx.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ros/ros.h>
2+
#include <boost/bind/bind.hpp>
23
#include <cv_bridge/cv_bridge.h>
34
#include <image_transport/image_transport.h>
45
#include <visualization_msgs/MarkerArray.h>
@@ -13,6 +14,9 @@
1314
#include <codecvt>
1415
#include <fstream>
1516
#include <sstream>
17+
#ifdef _WIN32
18+
#include "dml_provider_factory.h"
19+
#endif
1620

1721
using namespace std;
1822

@@ -141,11 +145,16 @@ bool OnnxProcessor::init(ros::NodeHandle &nh, ros::NodeHandle &nhPrivate)
141145
// initialize session options if needed
142146
Ort::SessionOptions session_options;
143147
session_options.SetIntraOpNumThreads(1);
148+
session_options.DisableMemPattern(); // Required for DirectML
149+
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); // Required for DirectML
144150
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
145151

146152
std::string modelString = _onnxModel;
147153

148154
#ifdef _WIN32
155+
// Select Device 0
156+
OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*)session_options, 0);
157+
149158
auto modelFullPath = to_wstring(modelString).c_str();
150159
#else
151160
auto modelFullPath = _onnxModel.c_str();
@@ -290,7 +299,7 @@ bool OnnxTracker::init(ros::NodeHandle &nh, ros::NodeHandle &nhPrivate)
290299
_nh = nh;
291300
_nhPrivate = nhPrivate;
292301

293-
f = boost::bind(&OnnxTracker::callback, this, _1, _2);
302+
f = boost::bind(&OnnxTracker::callback, this, boost::placeholders::_1, boost::placeholders::_2);
294303
server.setCallback(f);
295304

296305
return _status;

ros_msft_onnx/src/yolo_box.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ namespace yolo
112112

113113
if (_debug)
114114
{
115-
ROS_INFO("ONNX: %s detected!", label.c_str());
116115
// Draw a bounding box on the CV image
117116
cv::Scalar color(255, 255, 0);
118117
cv::Rect box;

0 commit comments

Comments
 (0)