@@ -44,17 +44,20 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   if (!optionsMap.contains("model-path")) {
     LOG(fatal) << "(ORT) Model path cannot be empty!";
   }
-  modelPath = optionsMap["model-path"];
-  device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
-  dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
-  deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
-  allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
-  intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
-  loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
-  enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
-  enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
-
-  std::string dev_mem_str = "Hip";
+
+  if (!optionsMap["model-path"].empty()) {
+    modelPath = optionsMap["model-path"];
+    device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
+    dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
+    deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
+    allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
+    intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
+    interOpNumThreads = (optionsMap.contains("inter-op-num-threads") ? std::stoi(optionsMap["inter-op-num-threads"]) : 0);
+    loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
+    enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
+    enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
+
+    std::string dev_mem_str = "Hip";
 #if defined(ORT_ROCM_BUILD)
 #if ORT_ROCM_BUILD == 1
   if (device == "ROCM") {
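
For orientation, a minimal caller-side sketch of the options this block parses (the keys are the ones looked up above; the values, and the bare OrtModel construction, are hypothetical):

    // Hypothetical usage; only "model-path" is required, everything else has a default.
    std::unordered_map<std::string, std::string> opts{
      {"model-path", "/path/to/model.onnx"},
      {"device", "CPU"},             // default "CPU"
      {"intra-op-num-threads", "4"}, // default 0
      {"inter-op-num-threads", "2"}, // option added by this patch, default 0
      {"logging-level", "1"}};       // default drops from 2 to 0 in this patch
    OrtModel model;
    model.reset(opts);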
@@ -88,12 +91,15 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
 
   if (device == "CPU") {
     (pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
-    if (intraOpNumThreads > 1) {
+    (pImplOrt->sessionOptions).SetInterOpNumThreads(interOpNumThreads);
+    if (intraOpNumThreads > 1 || interOpNumThreads > 1) {
       (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
     } else if (intraOpNumThreads == 1) {
       (pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
     }
-    LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
+    if (loggingLevel < 2) {
+      LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " (intraOpNumThreads) and " << interOpNumThreads << " (interOpNumThreads) threads";
+    }
   }
 
   (pImplOrt->sessionOptions).DisableMemPattern();
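
Switching to ORT_PARALLEL once either thread count exceeds one matches how ONNX Runtime treats the two knobs: inter-op threads only take effect in parallel execution mode. As a standalone sketch against the plain ONNX Runtime C++ API (thread counts invented):

    #include <onnxruntime_cxx_api.h>

    Ort::SessionOptions so;
    so.SetIntraOpNumThreads(4);                       // parallelism inside a single operator
    so.SetInterOpNumThreads(2);                       // parallelism across independent nodes,
    so.SetExecutionMode(ExecutionMode::ORT_PARALLEL); // ...which ORT only uses in parallel mode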
@@ -109,6 +115,9 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   } else {
     (pImplOrt->sessionOptions).DisableProfiling();
   }
+
+  mInitialized = true;
+
   (pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
   (pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
 
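
The two integers are cast directly into ONNX Runtime enums, which is also why the new loggingLevel < 2 guards above and below print only at the verbose/info severities. The values, from the ONNX Runtime C API:

    // GraphOptimizationLevel(enableOptimizations):
    //   0 = ORT_DISABLE_ALL, 1 = ORT_ENABLE_BASIC, 2 = ORT_ENABLE_EXTENDED, 99 = ORT_ENABLE_ALL
    // OrtLoggingLevel(loggingLevel):
    //   0 = VERBOSE, 1 = INFO, 2 = WARNING, 3 = ERROR, 4 = FATAL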
@@ -154,16 +163,9 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
   outputNamesChar.resize(mOutputNames.size(), nullptr);
   std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
                  [&](const std::string& str) { return str.c_str(); });
-
-  // Print names
-  LOG(info) << "\tInput Nodes:";
-  for (size_t i = 0; i < mInputNames.size(); i++) {
-    LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
   }
-
-  LOG(info) << "\tOutput Nodes:";
-  for (size_t i = 0; i < mOutputNames.size(); i++) {
-    LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
+  if (loggingLevel < 2) {
+    LOG(info) << "(ORT) Model loaded successfully! (input: " << printShape(mInputShapes[0]) << ", output: " << printShape(mOutputShapes[0]) << ")";
   }
 }
 
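
The per-node name/shape dump gives way to a one-line summary. For a hypothetical model with input shape {-1, 7} and output shape {-1, 1}, and assuming printShape (unchanged below) joins dimensions with 'x', the message would read roughly:

    (ORT) Model loaded successfully! (input: -1x7, output: -1x1)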
@@ -187,36 +189,6 @@ std::vector<O> OrtModel::v2v(std::vector<I>& input, bool clearInput)
   }
 }
 
-template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
-std::vector<O> OrtModel::inference(std::vector<I>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
-  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
-
-template <class I, class O> // class I is the input data type, e.g. float, class O is the output data type, e.g. O2::gpu::OrtDataType::Float16_t from O2/GPU/GPUTracking/ML/convert_float16.h
-std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
-{
-  std::vector<Ort::Value> inputTensor;
-  for (auto i : input) {
-    std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-    inputTensor.emplace_back(Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, reinterpret_cast<O*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
-  }
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
-  std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
-
 std::string OrtModel::printShape(const std::vector<int64_t>& v)
 {
   std::stringstream ss("");
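
The two bodies deleted above reappear below, merged into single templates plus explicit instantiations in this .cxx. The pattern in isolation, with a hypothetical helper that is not part of OrtModel:

    #include <vector>

    // Definition stays out of the header...
    template <class I, class O>
    std::vector<O> convert(std::vector<I>& in)
    {
      return std::vector<O>(in.begin(), in.end());
    }

    // ...so each type pair callers need must be instantiated explicitly, exactly
    // like the 'template std::vector<float> OrtModel::inference<...>' lines below.
    template std::vector<float> convert<float, float>(std::vector<float>&);
    template std::vector<double> convert<float, double>(std::vector<float>&);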
@@ -227,74 +199,68 @@ std::string OrtModel::printShape(const std::vector<int64_t>& v)
   return ss.str();
 }
 
-template <>
-std::vector<float> OrtModel::inference<float, float>(std::vector<float>& input)
+template <class I, class O>
+std::vector<O> OrtModel::inference(std::vector<I>& input)
 {
   std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
   std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<float>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
+  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
+  } else {
+    inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input.data(), input.size(), inputShape.data(), inputShape.size()));
+  }
   // input.clear();
   auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
-  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
+  O* outputValues = outputTensors[0].template GetTensorMutableData<O>();
+  std::vector<O> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
   outputTensors.clear();
   return outputValuesVec;
 }
 
-template <>
-std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  float* outputValues = outputTensors[0].template GetTensorMutableData<float>();
-  std::vector<float> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
+template std::vector<float> OrtModel::inference<float, float>(std::vector<float>&);
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>& input)
-{
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
-}
+template std::vector<float> OrtModel::inference<OrtDataType::Float16_t, float>(std::vector<OrtDataType::Float16_t>&);
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<float, OrtDataType::Float16_t>(std::vector<float>& input)
+template std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<OrtDataType::Float16_t>&);
+
+template <class I, class O>
+void OrtModel::inference(I* input, size_t input_size, O* output)
 {
-  std::vector<int64_t> inputShape{(int64_t)(input.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-  std::vector<Ort::Value> inputTensor;
-  inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input.data()), input.size(), inputShape.data(), inputShape.size()));
-  // input.clear();
-  auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputShape[0] * mOutputShapes[0][1]};
-  outputTensors.clear();
-  return outputValuesVec;
+  std::vector<int64_t> inputShape{(int64_t)(input_size / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
+  Ort::Value inputTensor = Ort::Value(nullptr);
+  if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+    inputTensor = Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(input), input_size, inputShape.data(), inputShape.size());
+  } else {
+    inputTensor = Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, input, input_size, inputShape.data(), inputShape.size());
+  }
+
+  std::vector<int64_t> outputShape{inputShape[0], mOutputShapes[0][1]};
+  size_t outputSize = (int64_t)(input_size * mOutputShapes[0][1] / mInputShapes[0][1]);
+  Ort::Value outputTensor = Ort::Value::CreateTensor<O>(pImplOrt->memoryInfo, output, outputSize, outputShape.data(), outputShape.size());
+
+  (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), &inputTensor, 1, outputNamesChar.data(), &outputTensor, outputNamesChar.size()); // TODO: Not sure if 1 is correct here
 }
 
-template <>
-std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t, OrtDataType::Float16_t>(std::vector<std::vector<OrtDataType::Float16_t>>& input)
+template void OrtModel::inference<OrtDataType::Float16_t, float>(OrtDataType::Float16_t*, size_t, float*);
+
+template void OrtModel::inference<float, float>(float*, size_t, float*);
+
+template <class I, class O>
+std::vector<O> OrtModel::inference(std::vector<std::vector<I>>& input)
 {
   std::vector<Ort::Value> inputTensor;
   for (auto i : input) {
     std::vector<int64_t> inputShape{(int64_t)(i.size() / mInputShapes[0][1]), (int64_t)mInputShapes[0][1]};
-    inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
+    if constexpr (std::is_same_v<I, OrtDataType::Float16_t>) {
+      inputTensor.emplace_back(Ort::Value::CreateTensor<Ort::Float16_t>(pImplOrt->memoryInfo, reinterpret_cast<Ort::Float16_t*>(i.data()), i.size(), inputShape.data(), inputShape.size()));
+    } else {
+      inputTensor.emplace_back(Ort::Value::CreateTensor<I>(pImplOrt->memoryInfo, i.data(), i.size(), inputShape.data(), inputShape.size()));
+    }
   }
   // input.clear();
   auto outputTensors = (pImplOrt->session)->Run(pImplOrt->runOptions, inputNamesChar.data(), inputTensor.data(), inputTensor.size(), outputNamesChar.data(), outputNamesChar.size());
-  OrtDataType::Float16_t* outputValues = reinterpret_cast<OrtDataType::Float16_t*>(outputTensors[0].template GetTensorMutableData<Ort::Float16_t>());
-  std::vector<OrtDataType::Float16_t> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
+  O* outputValues = reinterpret_cast<O*>(outputTensors[0].template GetTensorMutableData<O>());
+  std::vector<O> outputValuesVec{outputValues, outputValues + inputTensor.size() / mInputShapes[0][1] * mOutputShapes[0][1]};
   outputTensors.clear();
   return outputValuesVec;
 }