@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cuda.h>
+#include <cuda_runtime_api.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "NvInfer.h"
-#include "cuda.h"
-#include "cuda_runtime_api.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 
 namespace dy = paddle::platform::dynload;
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
 
 class ScopedWeights {
  public:
-  ScopedWeights(float value) : value_(value) {
+  explicit ScopedWeights(float value) : value_(value) {
     w.type = nvinfer1::DataType::kFLOAT;
     w.values = &value_;
     w.count = 1;
@@ -58,13 +58,13 @@ class ScopedWeights {
 // The following two API are implemented in TensorRT's header file, cannot load
 // from the dynamic library. So create our own implementation and directly
 // trigger the method from the dynamic library.
-nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
 
 const char* kInputTensor = "input";
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
 nvinfer1::IHostMemory* CreateNetwork() {
   Logger logger;
   // Create the engine.
-  nvinfer1::IBuilder* builder = createInferBuilder(logger);
+  nvinfer1::IBuilder* builder = createInferBuilder(&logger);
   ScopedWeights weights(2.);
   ScopedWeights bias(3.);
 
@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
   return model;
 }
 
-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
              float* output) {
-  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  const nvinfer1::ICudaEngine& engine = context->getEngine();
   // Two binds, input and output
   ASSERT_EQ(engine.getNbBindings(), 2);
   const int input_index = engine.getBindingIndex(kInputTensor);
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
   // Copy the input to the GPU, execute the network, and copy the output back.
   ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                                cudaMemcpyHostToDevice, stream));
-  context.enqueue(1, buffers, stream, nullptr);
+  context->enqueue(1, buffers, stream, nullptr);
   ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                                cudaMemcpyDeviceToHost, stream));
   cudaStreamSynchronize(stream);
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
 
   // Use the model to create an engine and an execution context.
   Logger logger;
-  nvinfer1::IRuntime* runtime = createInferRuntime(logger);
+  nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
   nvinfer1::ICudaEngine* engine =
       runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
   model->destroy();
@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
   // Execute the network.
   float input = 1234;
   float output;
-  Execute(*context, &input, &output);
+  Execute(context, &input, &output);
   EXPECT_EQ(output, input * 2 + 3);
 
   // Destroy the engine.
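For orientation, here is a minimal sketch of how the test flow reads once the helpers take pointers. It only restates what the hunks above already show; the `engine->createExecutionContext()` line is assumed from the standard TensorRT API, since the statement that creates `context` is not part of the visible diff.

    Logger logger;                                   // nvinfer1::ILogger subclass defined earlier in this file
    nvinfer1::IHostMemory* model = CreateNetwork();  // serialized engine built via createInferBuilder(&logger)

    // The wrappers now take the logger by pointer, so callers pass &logger.
    nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
    nvinfer1::ICudaEngine* engine =
        runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
    model->destroy();

    // Assumed standard TensorRT call; context creation is not shown in the hunks above.
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();

    // Execute() now takes the context by pointer instead of by reference.
    float input = 1234;
    float output;
    Execute(context, &input, &output);
    EXPECT_EQ(output, input * 2 + 3);  // the single-neuron network computes y = 2 * x + 3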