11// SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2- // Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ // Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33//
44// Licensed under the Apache License, Version 2.0 (the "License");
55// you may not use this file except in compliance with the License.
1717
1818#include " extensions/ess/components/ess_inference.hpp"
1919
20+ #include < dlfcn.h>
2021#include < algorithm>
2122#include < sstream>
2223#include < string>
2526#include " extensions/tensorops/components/ImageUtils.hpp"
2627#include " extensions/tensorops/core/Core.h"
2728#include " extensions/tensorops/core/Image.h"
28- #include " gems/gxf_helpers/expected_macro .hpp"
29+ #include " gems/gxf_helpers/expected_macro_gxf .hpp"
2930#include " gems/hash/hash_file.hpp"
3031#include " gems/video_buffer/allocator.hpp"
3132#include " gxf/cuda/cuda_stream_id.hpp"
@@ -40,46 +41,6 @@ namespace isaac {
4041
4142namespace detail {
4243
43- // Function to bind a cuda stream with cid into downstream message
44- gxf_result_t BindCudaStream (gxf::Entity& message, gxf_uid_t cid) {
45- if (cid == kNullUid ) {
46- GXF_LOG_ERROR (" stream_cid is null" );
47- return GXF_FAILURE;
48- }
49- auto output_stream_id = message.add <gxf::CudaStreamId>(" stream" );
50- if (!output_stream_id) {
51- GXF_LOG_ERROR (" failed to add cudastreamid." );
52- return GXF_FAILURE;
53- }
54- output_stream_id.value ()->stream_cid = cid;
55- return GXF_SUCCESS;
56- }
57-
58- // Function to record a new cuda event
59- gxf_result_t RecordCudaEvent (gxf::Entity& message, gxf::Handle<gxf::CudaStream>& stream) {
60- // Create a new event
61- cudaEvent_t cuda_event;
62- cudaEventCreateWithFlags (&cuda_event, 0 );
63- gxf::CudaEvent event;
64- auto ret = event.initWithEvent (cuda_event, stream->dev_id (), [](auto ) {});
65- if (!ret) {
66- GXF_LOG_ERROR (" failed to init cuda event" );
67- return GXF_FAILURE;
68- }
69- // Record the event
70- // Can define []() { GXF_LOG_DEBUG("tensorops event synced"); }
71- // as callback func for debug purpose
72- ret = stream->record (event.event ().value (),
73- [event = cuda_event, entity = message.clone ().value ()](auto ) {
74- cudaEventDestroy (event);
75- });
76- if (!ret) {
77- GXF_LOG_ERROR (" record event failed" );
78- return ret.error ();
79- }
80- return GXF_SUCCESS;
81- }
82-
8344template <typename T>
8445gxf_result_t PassthroughComponents (gxf::Entity& output, gxf::Entity& input,
8546 const char * name = nullptr ) {
@@ -101,7 +62,7 @@ gxf_result_t PassthroughComponents(gxf::Entity& output, gxf::Entity& input,
10162gxf::Expected<std::string> ComputeEnginePath (
10263 const ess::ModelBuildParams& modelBuildParams) {
10364 const SHA256::String onnx_hash =
104- GXF_UNWRAP_OR_RETURN (hash_file (modelBuildParams.onnx_file_path .c_str ()));
65+ UNWRAP_OR_RETURN (hash_file (modelBuildParams.onnx_file_path .c_str ()));
10566
10667 std::string target_dir = " /tmp" ;
10768
@@ -189,6 +150,11 @@ gxf_result_t ESSInference::registerInterface(gxf::Registrar* registrar) {
189150 result &= registrar->parameter (
190151 onnx_file_path_, " onnx_file_path" , " ONNX file path" ,
191152 " The path to the onnx model file" );
153+ result &= registrar->parameter (
154+ tensorrt_plugin_path_, " tensorrt_plugin" , " TensorRT plugin path" ,
155+ " The path to the TensorRT plugin file" ,
156+ gxf::Registrar::NoDefaultParameter (),
157+ GXF_PARAMETER_FLAGS_OPTIONAL);
192158 result &= registrar->parameter (
193159 enable_fp16_, " enable_fp16" , " Enable FP16" ,
194160 " Flag to enable FP16 in engine generation" , true );
@@ -276,6 +242,14 @@ gxf_result_t ESSInference::start() {
276242 return GXF_FAILURE;
277243 }
278244
245+ // Load ESS plugin
246+ if (tensorrt_plugin_path_.try_get () && !tensorrt_plugin_path_.try_get ().value ().empty ()) {
247+ if (!dlopen (tensorrt_plugin_path_.try_get ().value ().c_str (), RTLD_NOW)) {
248+ GXF_LOG_ERROR (" ESS plugin loading failed." );
249+ return GXF_FAILURE;
250+ }
251+ }
252+
279253 // Setting engine build params
280254 auto maybe_dla_core = dla_core_.try_get ();
281255 const int64_t dla_core = dla_core_.try_get ().value_or (-1 );
@@ -290,7 +264,7 @@ gxf_result_t ESSInference::start() {
290264 engine_file_path = maybe_engine_file_path.value ();
291265 } else {
292266 engine_file_path =
293- GXF_UNWRAP_OR_RETURN (detail::ComputeEnginePath (model_build_params_));
267+ UNWRAP_OR_RETURN (detail::ComputeEnginePath (model_build_params_));
294268 }
295269
296270 // Setting inference params for ESS
@@ -336,41 +310,11 @@ gxf_result_t ESSInference::tick() {
336310 if (!inputLeftMessage) {
337311 return GXF_FAILURE;
338312 }
339- if (cuda_stream != 0 ) {
340- detail::RecordCudaEvent (inputLeftMessage.value (), cuda_stream_);
341- auto inputLeftStreamId = inputLeftMessage.value ().get <gxf::CudaStreamId>(" stream" );
342- if (inputLeftStreamId) {
343- auto inputLeftStream = gxf::Handle<gxf::CudaStream>::Create (
344- inputLeftStreamId.value ().context (),
345- inputLeftStreamId.value ()->stream_cid );
346- // NOTE: This is an expensive call. It will halt the current CPU thread until all events
347- // previously associated with the stream are cleared
348- if (!inputLeftStream.value ()->syncStream ()) {
349- GXF_LOG_ERROR (" sync left stream failed." );
350- return GXF_FAILURE;
351- }
352- }
353- }
354313
355314 auto inputRightMessage = right_image_receiver_->receive ();
356315 if (!inputRightMessage) {
357316 return GXF_FAILURE;
358317 }
359- if (cuda_stream != 0 ) {
360- detail::RecordCudaEvent (inputRightMessage.value (), cuda_stream_);
361- auto inputRightStreamId = inputRightMessage.value ().get <gxf::CudaStreamId>(" stream" );
362- if (inputRightStreamId) {
363- auto inputRightStream = gxf::Handle<gxf::CudaStream>::Create (
364- inputRightStreamId.value ().context (),
365- inputRightStreamId.value ()->stream_cid );
366- // NOTE: This is an expensive call. It will halt the current CPU thread until all events
367- // previously associated with the stream are cleared
368- if (!inputRightStream.value ()->syncStream ()) {
369- GXF_LOG_ERROR (" sync right stream failed." );
370- return GXF_FAILURE;
371- }
372- }
373- }
374318
375319 auto maybeLeftName = left_image_name_.try_get ();
376320 auto leftInputBuffer = inputLeftMessage.value ().get <gxf::VideoBuffer>(
@@ -490,12 +434,6 @@ gxf_result_t ESSInference::tick() {
490434 return GXF_FAILURE;
491435 }
492436
493- // Allocate a cuda event that can be used to record on each tick
494- if (!cuda_stream_.is_null ()) {
495- detail::BindCudaStream (outputMessage.value (), cuda_stream_.cid ());
496- detail::RecordCudaEvent (outputMessage.value (), cuda_stream_);
497- }
498-
499437 // Pass down timestamp if necessary
500438 auto maybeDaqTimestamp =
501439 timestamp_policy_.get () == 0 ? inputLeftMessage.value ().get <gxf::Timestamp>()
@@ -553,12 +491,12 @@ gxf_result_t ESSInference::tick() {
553491 *output_model = *maybe_scaled_model;
554492 *output_conf = *maybe_scaled_model;
555493 } else {
556- GXF_LOG_WARNING (" Input message is missing intrinsics!" );
494+ GXF_LOG_DEBUG (" Input message is missing intrinsics!" );
557495 }
558496
559497 // Publish the data and confidence
560- GXF_RETURN_IF_ERROR (gxf::ToResultCode (output_transmitter_->publish (outputMessage.value ())));
561- GXF_RETURN_IF_ERROR (gxf::ToResultCode (confidence_transmitter_->publish (
498+ RETURN_IF_ERROR (gxf::ToResultCode (output_transmitter_->publish (outputMessage.value ())));
499+ RETURN_IF_ERROR (gxf::ToResultCode (confidence_transmitter_->publish (
562500 outputConfMessage.value ())));
563501
564502 return GXF_SUCCESS;
0 commit comments