
Commit 992ffd8

[Tensorrt] update features and bug fixes (#8961)
#8863 #8943 #8942 #8940 #8954 #8953

1 parent: 92bc457

Note: large commits have some content hidden by default, so several file paths below are not shown.

59 files changed: +1809 −277 lines

lite/backends/nnadapter/nnadapter/include/nnadapter/nnadapter.h

Lines changed: 21 additions & 0 deletions
@@ -1106,6 +1106,27 @@ typedef enum {
    */
   NNADAPTER_LOG,
 
+  /**
+   * Computes the log of softmax values for input.
+   * The output is calculated using this formula:
+   *     output = log(exp(input) / reduce_sum(exp(input), axis=axis,
+   *     keepdims=true))
+   *
+   * Inputs:
+   * * 0: input, a NNADAPTER_FLOAT32,
+   *      NNADAPTER_QUANT_INT8_SYMM_PER_LAYER tensor.
+   * * 1: axis, a NNADAPTER_INT32 scalar. Defaults to 1. It represents the
+   *      dimension along which softmax will be performed. It should be in
+   *      range [-R, R), where R is the rank of input, negative value works
+   *      the same way as axis+R.
+   *
+   * Outputs:
+   * * 0: output, a tensor with the same shape and type as input.
+   *
+   * Available since version 1.
+   */
+  NNADAPTER_LOG_SOFTMAX,
+
   /**
    * Applies the Lp Normalization to the input tensor element-wise.
    * The output is calculated using this formula:
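
The formula in the new enum comment is standard log-softmax. As a quick sanity check, here is a minimal standalone sketch (not part of this commit) that evaluates it for a 1-D vector, using the same max-subtraction trick for numerical stability that the reference implementation added later in this commit uses:

// Illustrative only: numerically stable log_softmax over a 1-D vector.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  // Subtract the max before exponentiating to avoid overflow.
  float max_value = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - max_value);
  for (float v : x) {
    // log_softmax(x_i) = (x_i - max) - log(sum_j exp(x_j - max))
    std::printf("%f\n", (v - max_value) - std::log(sum));
  }
  return 0;
}

For the input {1, 2, 3} this prints approximately -2.4076, -1.4076, -0.4076.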
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace nnadapter {
+namespace operation {
+
+#define LOG_SOFTMAX_OPERATION_EXTRACT_INPUTS_OUTPUTS                     \
+  auto& input_operands = operation->input_operands;                     \
+  auto& output_operands = operation->output_operands;                   \
+  auto input_count = input_operands.size();                             \
+  auto output_count = output_operands.size();                           \
+  NNADAPTER_CHECK_EQ(input_count, 2);                                   \
+  NNADAPTER_CHECK_EQ(output_count, 1);                                  \
+  /* Input */                                                           \
+  auto input_operand = input_operands[0];                               \
+  NNADAPTER_VLOG(5) << "input: " << OperandToString(input_operand);     \
+  /* Axis */                                                            \
+  auto axis = *reinterpret_cast<int32_t*>(input_operands[1]->buffer);   \
+  if (axis < 0) {                                                       \
+    axis += input_operand->type.dimensions.count;                       \
+  }                                                                     \
+  NNADAPTER_VLOG(5) << "axis=" << axis;                                 \
+  /* Output */                                                          \
+  auto output_operand = output_operands[0];                             \
+  NNADAPTER_VLOG(5) << "output: " << OperandToString(output_operand);
+
+}  // namespace operation
+}  // namespace nnadapter
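
This macro is expanded at the top of each driver's converter; a minimal sketch of the pattern (illustrative only, the Huawei Ascend NPU converter later in this commit is the real consumer):

int ConvertLogSoftmax(Converter* converter, core::Operation* operation) {
  LOG_SOFTMAX_OPERATION_EXTRACT_INPUTS_OUTPUTS
  // input_operand, a non-negative axis (the macro adds the input rank to
  // negative values), and output_operand are now in scope; backend-specific
  // conversion code follows here.
  return NNADAPTER_NO_ERROR;
}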
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+#include "operation/math/dequantize.h"
+#include "operation/math/quantize.h"
+#include "operation/math/utility.h"
+
+namespace nnadapter {
+namespace operation {
+namespace math {
+
+template <typename T>
+static int log_softmax(const T* input_data,
+                       const std::vector<int32_t>& input_shape,
+                       int axis,
+                       T* output_data) {
+  if (!input_data || !output_data) {
+    return -1;
+  }
+  auto input_rank = input_shape.size();
+  if (axis < 0) {
+    axis += input_rank;
+  }
+  auto axis_count = input_shape[axis];
+  auto outer_count = shape_production(shape_slice(input_shape, 0, axis));
+  auto inner_count =
+      shape_production(shape_slice(input_shape, axis + 1, input_rank));
+  auto compute_count = outer_count * inner_count;
+  for (int64_t i = 0; i < compute_count; i++) {
+    auto inner_index = i % inner_count;
+    auto outer_index = (i / inner_count) * axis_count;
+    auto start = outer_index * inner_count + inner_index;
+    auto offset = start;
+    auto max_value = std::numeric_limits<T>::lowest();
+    for (int j = 0; j < axis_count; j++) {
+      max_value =
+          input_data[offset] > max_value ? input_data[offset] : max_value;
+      offset += inner_count;
+    }
+    offset = start;
+    T sum_value = 0;
+    for (int j = 0; j < axis_count; j++) {
+      output_data[offset] = std::exp(input_data[offset] - max_value);
+      sum_value += output_data[offset];
+      offset += inner_count;
+    }
+    offset = start;
+    for (int j = 0; j < axis_count; j++) {
+      output_data[offset] /= sum_value;
+      output_data[offset] = std::log(output_data[offset]);
+      offset += inner_count;
+    }
+  }
+  return 0;
+}
+
+int log_softmax(const int8_t* input_data,
+                const std::vector<int32_t>& input_shape,
+                float input_scale,
+                int axis,
+                int8_t* output_data,
+                float output_scale);
+
+}  // namespace math
+}  // namespace operation
+}  // namespace nnadapter
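
The int8_t overload is only declared here; its definition is not in this hunk. A plausible sketch (an assumption, not the commit's actual code) dequantizes to float, reuses the template above, and requantizes with the output scale:

// Assumed implementation sketch; the real one lives in the matching .cc file
// and may use the dequantize()/quantize() helpers included above instead.
int log_softmax(const int8_t* input_data,
                const std::vector<int32_t>& input_shape,
                float input_scale,
                int axis,
                int8_t* output_data,
                float output_scale) {
  int64_t count = shape_production(input_shape);
  std::vector<float> x(count), y(count);
  for (int64_t i = 0; i < count; i++) x[i] = input_data[i] * input_scale;
  int status = log_softmax(x.data(), input_shape, axis, y.data());
  if (status != 0) return status;
  for (int64_t i = 0; i < count; i++) {
    // Round to the nearest representable int8 value and saturate.
    int v = static_cast<int>(std::round(y[i] / output_scale));
    output_data[i] = static_cast<int8_t>(std::min(127, std::max(-128, v)));
  }
  return 0;
}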

lite/backends/nnadapter/nnadapter/include/nnadapter/operation/split.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ namespace operation {
   } else {                                                              \
     NNADAPTER_VLOG(5) << "axis: " << OperandToString(axis_operand);     \
   }                                                                     \
+  NNADAPTER_CHECK_LT(axis, input_operand->type.dimensions.count);       \
   /* Split */                                                           \
   auto split_operand = input_operands[2];                               \
   std::vector<int> split;                                               \

lite/backends/nnadapter/nnadapter/src/driver/huawei_ascend_npu/converter/all.h

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ REGISTER_CONVERTER(LEAKY_RELU, ConvertLeakyRelu)
 REGISTER_CONVERTER(LESS, ConvertComparisons)
 REGISTER_CONVERTER(LESS_EQUAL, ConvertComparisons)
 REGISTER_CONVERTER(LOG, ConvertUnaryActivations)
+REGISTER_CONVERTER(LOG_SOFTMAX, ConvertLogSoftmax)
 REGISTER_CONVERTER(LP_NORMALIZATION, ConvertLpNormalization)
 REGISTER_CONVERTER(MAT_MUL, ConvertMatMul)
 REGISTER_CONVERTER(MAX, ConvertElementwise)
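
Registries like this all.h follow the X-macro pattern: the header carries no definition of REGISTER_CONVERTER itself, so each including file defines it first and undefines it after. A hedged sketch of a typical consumer (the exact definitions in the driver sources may differ):

// Declare one converter function per registered operation (illustrative).
#define REGISTER_CONVERTER(__op_type__, __func_name__) \
  int __func_name__(Converter* converter, core::Operation* operation);
#include "driver/huawei_ascend_npu/converter/all.h"  // NOLINT
#undef REGISTER_CONVERTER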
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "operation/log_softmax.h"
+#include "driver/huawei_ascend_npu/converter/converter.h"
+#include "utility/debug.h"
+#include "utility/logging.h"
+
+namespace nnadapter {
+namespace huawei_ascend_npu {
+
+int ConvertLogSoftmax(Converter* converter, core::Operation* operation) {
+  LOG_SOFTMAX_OPERATION_EXTRACT_INPUTS_OUTPUTS
+
+  // Convert to GE operators
+  auto input_operator = converter->GetMappedOperator(input_operand);
+  if (!input_operator) {
+    input_operator = converter->ConvertOperand(input_operand);
+  }
+  auto log_softmax_op =
+      converter->AddOperator<ge::op::LogSoftmaxV2>(output_operand);
+  log_softmax_op->set_attr_axes({axis});
+  SET_INPUT(log_softmax_op, logits, input_operator);
+  MAP_OUTPUT(log_softmax_op, logsoftmax, output_operand);
+  return NNADAPTER_NO_ERROR;
+}
+
+}  // namespace huawei_ascend_npu
+}  // namespace nnadapter

lite/backends/nnadapter/nnadapter/src/driver/nvidia_tensorrt/calibrator.cc

Lines changed: 5 additions & 4 deletions
@@ -54,9 +54,9 @@ Int8EntropyCalibrator::Int8EntropyCalibrator(int batch_size,
 
 bool Int8EntropyCalibrator::getBatch(void* bindings[],
                                      const char* names[],
-                                     int nbBindings) {
+                                     int nb_bindings) TRT_NOEXCEPT {
   // TODO(zhupengyang): support multi inputs
-  NNADAPTER_CHECK_EQ(nbBindings, 1);
+  NNADAPTER_CHECK_EQ(nb_bindings, 1);
   if (static_cast<size_t>(index_) >= input_file_names_.at(0).size()) {
     return false;
   }
@@ -89,7 +89,8 @@ bool Int8EntropyCalibrator::getBatch(void* bindings[],
   return true;
 }
 
-const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length) {
+const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length)
+    TRT_NOEXCEPT {
   if (table_path_.empty()) {
     NNADAPTER_LOG(WARNING) << "No calibration table file is set. New "
                               "calibration table will be generated.";
@@ -106,7 +107,7 @@ const void* Int8EntropyCalibrator::readCalibrationCache(size_t& length) {
 }
 
 void Int8EntropyCalibrator::writeCalibrationCache(const void* cache,
-                                                  size_t length) {
+                                                  size_t length) TRT_NOEXCEPT {
   if (table_path_.empty()) {
     NNADAPTER_LOG(WARNING) << "No calibration table will be saved because "
                               "table_path is not found.";

lite/backends/nnadapter/nnadapter/src/driver/nvidia_tensorrt/calibrator.h

Lines changed: 8 additions & 5 deletions
@@ -27,12 +27,15 @@ class Int8EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
   Int8EntropyCalibrator(int batch_size,
                         std::string dataset_path,
                         std::string table_path);
-  virtual ~Int8EntropyCalibrator() {}
+  virtual ~Int8EntropyCalibrator() TRT_NOEXCEPT {}
 
-  int getBatchSize() const override { return batch_size_; }
-  bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
-  const void* readCalibrationCache(size_t& length) override;
-  void writeCalibrationCache(const void* cache, size_t length) override;
+  int getBatchSize() const TRT_NOEXCEPT override { return batch_size_; }
+  bool getBatch(void* bindings[],
+                const char* names[],
+                int nb_bindings) TRT_NOEXCEPT override;
+  const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
+  void writeCalibrationCache(const void* cache,
+                             size_t length) TRT_NOEXCEPT override;
 
  private:
   int batch_size_{1};
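
For context, a calibrator like this is handed to the builder config when INT8 mode is enabled, and TensorRT then pulls batches through getBatch() while building the engine. A minimal usage sketch (the paths, the batch size, and the pre-existing `builder` pointer are made-up assumptions):

// Sketch: wiring the calibrator into a TensorRT builder config.
Int8EntropyCalibrator calibrator(1, "./calib_dataset", "./calib.table");
auto* config = builder->createBuilderConfig();
config->setFlag(nvinfer1::BuilderFlag::kINT8);
config->setInt8Calibrator(&calibrator);
// The calibrator must stay alive until engine building completes.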

lite/backends/nnadapter/nnadapter/src/driver/nvidia_tensorrt/converter/all.h

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@ REGISTER_CONVERTER(CAST, ConvertCast)
 REGISTER_CONVERTER(CLIP, ConvertClip)
 REGISTER_CONVERTER(CONCAT, ConvertConcat)
 REGISTER_CONVERTER(CONV_2D, ConvertConv2D)
+REGISTER_CONVERTER(CONV_2D_TRANSPOSE, ConvertConv2DTranspose)
 REGISTER_CONVERTER(DIV, ConvertElementwise)
 REGISTER_CONVERTER(EQUAL, ConvertComparisons)
 REGISTER_CONVERTER(EXP, ConvertUnaryOperations)
@@ -33,6 +34,7 @@ REGISTER_CONVERTER(FULLY_CONNECTED, ConvertFullyConnected)
 REGISTER_CONVERTER(HARD_SWISH, ConvertHardSwish)
 REGISTER_CONVERTER(LEAKY_RELU, ConvertLeakyRelu)
 REGISTER_CONVERTER(LOG, ConvertUnaryOperations)
+REGISTER_CONVERTER(LOG_SOFTMAX, ConvertLogSoftmax)
 REGISTER_CONVERTER(MAT_MUL, ConvertMatMul)
 REGISTER_CONVERTER(MAX_POOL_2D, ConvertPool2D)
 REGISTER_CONVERTER(MUL, ConvertElementwise)
@@ -49,6 +51,7 @@ REGISTER_CONVERTER(SIGMOID, ConvertActivations)
 REGISTER_CONVERTER(SLICE, ConvertSlice)
 REGISTER_CONVERTER(SOFTMAX, ConvertSoftmax)
 REGISTER_CONVERTER(SQUEEZE, ConvertSqueeze)
+REGISTER_CONVERTER(SPLIT, ConvertSplit)
 REGISTER_CONVERTER(STACK, ConvertStack)
 REGISTER_CONVERTER(SUB, ConvertElementwise)
 REGISTER_CONVERTER(SWISH, ConvertSwish)

lite/backends/nnadapter/nnadapter/src/driver/nvidia_tensorrt/converter/batch_normalization.cc

Lines changed: 44 additions & 11 deletions
@@ -39,17 +39,37 @@ int ConvertBatchNormalization(Converter* converter,
   NNADAPTER_CHECK(bias_ptr);
   NNADAPTER_CHECK(mean_ptr);
   NNADAPTER_CHECK(var_ptr);
-  // prepare data
-  auto x_dim = input_operand->type.dimensions;
-  NNADAPTER_CHECK_EQ(scale_operand->type.dimensions.data[0], x_dim.data[1]);
-  NNADAPTER_CHECK_EQ(bias_operand->type.dimensions.data[0], x_dim.data[1]);
-  NNADAPTER_CHECK_EQ(mean_operand->type.dimensions.data[0], x_dim.data[1]);
-  NNADAPTER_CHECK_EQ(variance_operand->type.dimensions.data[0], x_dim.data[1]);
-  std::vector<float> fuse_scale(x_dim.data[1], 0);
-  std::vector<float> fuse_bias(x_dim.data[1], 0);
+  auto input_tensor_dim = input_tensor->getDimensions();
+  // Add shuffle operator to reshape data into 3 dimensions
+  if (input_tensor_dim.nbDims < 3) {
+    nvinfer1::Dims unsqueeze_shape;
+    unsqueeze_shape.nbDims = 3;
+    for (int i = 0; i < 3; i++) {
+      if (i < input_tensor_dim.nbDims) {
+        unsqueeze_shape.d[i] =
+            input_tensor_dim.d[i] < 0 ? 0 : input_tensor_dim.d[i];
+      } else {
+        unsqueeze_shape.d[i] = 1;
+      }
+    }
+    auto unsqueeze_layer = converter->network()->addShuffle(*input_tensor);
+    unsqueeze_layer->setReshapeDimensions(unsqueeze_shape);
+    input_tensor = unsqueeze_layer->getOutput(0);
+  }
+  // Add batch_normalization op using ScaleNd operator
+  NNADAPTER_CHECK_EQ(scale_operand->type.dimensions.data[0],
+                     input_tensor_dim.d[0]);
+  NNADAPTER_CHECK_EQ(bias_operand->type.dimensions.data[0],
+                     input_tensor_dim.d[0]);
+  NNADAPTER_CHECK_EQ(mean_operand->type.dimensions.data[0],
+                     input_tensor_dim.d[0]);
+  NNADAPTER_CHECK_EQ(variance_operand->type.dimensions.data[0],
+                     input_tensor_dim.d[0]);
+  std::vector<float> fuse_scale(input_tensor_dim.d[0], 0);
+  std::vector<float> fuse_bias(input_tensor_dim.d[0], 0);
   auto fuse_scale_ptr = fuse_scale.data();
   auto fuse_bias_ptr = fuse_bias.data();
-  for (int i = 0; i < x_dim.data[1]; i++) {
+  for (int i = 0; i < input_tensor_dim.d[0]; i++) {
     fuse_scale_ptr[i] = scale_ptr[i] / sqrtf(var_ptr[i] + epsilon);
     fuse_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * fuse_scale_ptr[i];
   }
@@ -58,9 +78,9 @@ int ConvertBatchNormalization(Converter* converter,
   const float* power_ptr = nullptr;
   // add scale op
   nvinfer1::Weights scale_w =
-      converter->AddWeights(fuse_scale_ptr_const, x_dim.data[1]);
+      converter->AddWeights(fuse_scale_ptr_const, input_tensor_dim.d[0]);
   nvinfer1::Weights shift_w =
-      converter->AddWeights(fuse_bias_ptr_const, x_dim.data[1]);
+      converter->AddWeights(fuse_bias_ptr_const, input_tensor_dim.d[0]);
   nvinfer1::Weights power_w = converter->AddWeights(power_ptr, 0);
   auto layer = converter->network()->addScaleNd(*input_tensor,
                                                 nvinfer1::ScaleMode::kCHANNEL,
@@ -69,6 +89,19 @@ int ConvertBatchNormalization(Converter* converter,
                                                 power_w,
                                                 0);
   auto output_tensor = layer->getOutput(0);
+  // Add shuffle operator to recover shape
+  if (input_tensor_dim.nbDims < 3) {
+    nvinfer1::Dims squeeze_shape;
+    squeeze_shape.nbDims = input_tensor_dim.nbDims;
+    for (int i = 0; i < squeeze_shape.nbDims; i++) {
+      squeeze_shape.d[i] =
+          input_tensor_dim.d[i] < 0 ? 0 : input_tensor_dim.d[i];
+    }
+    auto squeeze_layer =
+        converter->network()->addShuffle(*(layer->getOutput(0)));
+    squeeze_layer->setReshapeDimensions(squeeze_shape);
+    output_tensor = squeeze_layer->getOutput(0);
+  }
   converter->UpdateTensorMap(output_operand, output_tensor);
   return NNADAPTER_NO_ERROR;
 }
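
The per-channel fusion computed above follows from rewriting batch normalization as a single affine transform:

    BN(x) = scale * (x - mean) / sqrt(var + epsilon) + bias
          = fuse_scale * x + fuse_bias
    where fuse_scale = scale / sqrt(var + epsilon)
          fuse_bias  = bias - mean * fuse_scale

which is exactly what fuse_scale_ptr[i] and fuse_bias_ptr[i] hold, so a single ScaleNd layer in kCHANNEL mode (scale = fuse_scale, shift = fuse_bias, power unused) implements the whole operation. The shuffle layers exist because, per the diff's own comments, inputs with fewer than 3 dimensions must be unsqueezed before the scale layer and squeezed back afterwards.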
