@@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
7979
8080#define TRT_TRY (...) OUTCOME_TRY(trt_try(__VA_ARGS__))  // Forwards its arguments to trt_try() and hands the returned Result<void> to OUTCOME_TRY.
8181
82- TRTNet::~TRTNet () = default ;
82+ TRTNet::~TRTNet () {  // Destructor: explicitly releases context_ and engine_ while a device guard is alive.
83+ CudaDeviceGuard guard (device_);  // Scoped guard built from device_; presumably makes it the current CUDA device until scope exit -- TODO confirm.
84+ context_.reset ();  // Reset before engine_; NOTE(review): presumably the context depends on the engine, so it goes first -- confirm.
85+ engine_.reset ();  // Reset last, still inside the guard's scope.
86+ }
8387
8488static Result<DataType> MapDataType (nvinfer1::DataType dtype) {
8589 switch (dtype) {
@@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
106110 MMDEPLOY_ERROR (" TRTNet: device must be a GPU!" );
107111 return Status (eNotSupported);
108112 }
113+ CudaDeviceGuard guard (device_);
109114 stream_ = context[" stream" ].get <Stream>();
110115
111116 event_ = Event (device_);
@@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
156161 return success ();
157162}
158163
159- Result<void > TRTNet::Deinit () {
160- context_.reset ();
161- engine_.reset ();
162- return success ();
163- }
164+ Result<void > TRTNet::Deinit () { return success (); }  // No-op that always returns success(); context_ and engine_ are reset in ~TRTNet() instead.
164165
165166Result<void > TRTNet::Reshape (Span<TensorShape> input_shapes) {
167+ CudaDeviceGuard guard (device_);
166168 using namespace trt_detail ;
167169 if (input_shapes.size () != input_tensors_.size ()) {
168170 return Status (eInvalidArgument);
@@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
190192Result<Span<Tensor>> TRTNet::GetOutputTensors () { return output_tensors_; }
191193
192194Result<void > TRTNet::Forward () {
195+ CudaDeviceGuard guard (device_);
193196 using namespace trt_detail ;
194197 std::vector<void *> bindings (engine_->getNbBindings ());
195198
0 commit comments