|
1 | 1 | #include <ros/ros.h> |
| 2 | +#include <boost/bind/bind.hpp> |
2 | 3 | #include <cv_bridge/cv_bridge.h> |
3 | 4 | #include <image_transport/image_transport.h> |
4 | 5 | #include <visualization_msgs/MarkerArray.h> |
|
13 | 14 | #include <codecvt> |
14 | 15 | #include <fstream> |
15 | 16 | #include <sstream> |
| 17 | +#ifdef _WIN32 |
| 18 | +#include "dml_provider_factory.h" |
| 19 | +#endif |
16 | 20 |
|
17 | 21 | using namespace std; |
18 | 22 |
|
@@ -141,11 +145,16 @@ bool OnnxProcessor::init(ros::NodeHandle &nh, ros::NodeHandle &nhPrivate) |
141 | 145 | // initialize session options if needed |
142 | 146 | Ort::SessionOptions session_options; |
143 | 147 | session_options.SetIntraOpNumThreads(1); |
| 148 | + session_options.DisableMemPattern(); // Required for DirectML |
| 149 | + session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); // Required for DirectML |
144 | 150 | session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); |
145 | 151 |
|
146 | 152 | std::string modelString = _onnxModel; |
147 | 153 |
|
148 | 154 | #ifdef _WIN32 |
| 155 | + // Select Device 0 |
| 156 | + OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*)session_options, 0); |
| 157 | + |
149 | 158 | auto modelFullPath = to_wstring(modelString).c_str(); |
150 | 159 | #else |
151 | 160 | auto modelFullPath = _onnxModel.c_str(); |
@@ -290,7 +299,7 @@ bool OnnxTracker::init(ros::NodeHandle &nh, ros::NodeHandle &nhPrivate) |
290 | 299 | _nh = nh; |
291 | 300 | _nhPrivate = nhPrivate; |
292 | 301 |
|
293 | | - f = boost::bind(&OnnxTracker::callback, this, _1, _2); |
| 302 | + f = boost::bind(&OnnxTracker::callback, this, boost::placeholders::_1, boost::placeholders::_2); |
294 | 303 | server.setCallback(f); |
295 | 304 |
|
296 | 305 | return _status; |
|
0 commit comments