Commit 895542c

Support dynamic input in expand layer, expand_as layer and repeat layer
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 2b50334
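
For context: these converters must reproduce PyTorch's expand/expand_as/repeat semantics when the TensorRT engine is built with dynamic input shapes. A minimal libtorch sketch of the operator semantics being converted (shapes and values here are illustrative, not taken from the commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::rand({3, 1});

  // aten::expand: broadcast [3, 1] -> [3, 4]; a -1 entry keeps the existing size
  auto e = x.expand({3, 4});   // view with stride 0 along the broadcast dim
  auto e2 = x.expand({-1, 4}); // -1 leaves dim 0 at its current size (3)

  // aten::expand_as: the target shape comes from another tensor at runtime
  auto other = torch::rand({2, 3, 4});
  auto ea = x.expand_as(other); // [3, 1] -> [2, 3, 4], dims right-aligned

  // aten::repeat: tile the tensor; repeats may have higher rank than the input
  auto r = x.repeat({2, 1, 2}); // [3, 1] treated as [1, 3, 1] -> [2, 3, 2]

  std::cout << e.sizes() << " " << e2.sizes() << " "
            << ea.sizes() << " " << r.sizes() << std::endl;
  return 0;
}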

File tree: 5 files changed (+949, -40 lines)

core/conversion/converters/impl/expand.cpp

Lines changed: 184 additions & 36 deletions
@@ -3,6 +3,7 @@
 #include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/prelude.h"
 #include "core/util/trt_util.h"
+#include "plugins/checkshape_plugin.h"
 #include "torch/torch.h"
 
 #include <ATen/ATen.h>
@@ -15,24 +16,93 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ILayer* create_plugin(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* inShape,
+    nvinfer1::ITensor* expandShape,
+    int32_t in_rank,
+    int32_t expand_rank,
+    const char* name) {
+  auto creator = new plugins::CheckShapePluginCreator();
+  std::vector<nvinfer1::PluginField> fields;
+  nvinfer1::PluginField input_rank("input_rank", &in_rank, nvinfer1::PluginFieldType::kINT32, 1);
+  nvinfer1::PluginField output_rank("expand_rank", &expand_rank, nvinfer1::PluginFieldType::kINT32, 1);
+  fields.push_back(input_rank);
+  fields.push_back(output_rank);
+  nvinfer1::PluginFieldCollection collection;
+  collection.nbFields = fields.size();
+  collection.fields = fields.data();
+  auto plugin = creator->createPlugin(name, &collection);
+
+  nvinfer1::ITensor* inputs[] = {inShape, expandShape};
+  auto expandShape_layer = ctx->net->addPluginV2(inputs, 2, *plugin);
+  TRTORCH_CHECK(expandShape_layer, "Unable to create CheckShape plugin from node " << *n);
+
+  expandShape_layer->setName("CheckShapePlugin");
+  return expandShape_layer;
+}
+
+void addSliceInput(nvinfer1::Dims& dims, int idx, ConversionCtx* ctx, nvinfer1::ISliceLayer* slice) {
+  int32_t rank = static_cast<int32_t>(dims.nbDims);
+  int32_t* tmp = new int32_t[rank];
+  for (int i = 0; i < rank; i++)
+    tmp[i] = dims.d[i];
+  const nvinfer1::Dims d{1, {rank}};
+  const nvinfer1::Weights w{nvinfer1::DataType::kINT32, tmp, rank};
+  auto t = ctx->net->addConstant(d, w)->getOutput(0);
+  slice->setInput(idx, *t);
+}
+
+nvinfer1::ITensor* vec2Tensor(int32_t* dim, int rank, ConversionCtx* ctx) {
+  const nvinfer1::Dims d{1, {static_cast<int32_t>(rank)}};
+  const nvinfer1::Weights w{nvinfer1::DataType::kINT32, dim, rank};
+  return ctx->net->addConstant(d, w)->getOutput(0);
+}
+
+nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
+  if (max_rank - old_rank > 0) {
+    // Pad the input shape with leading 1s up to max_rank
+    int32_t* tmp = new int32_t[max_rank - old_rank];
+    for (int i = 0; i < (max_rank - old_rank); i++)
+      tmp[i] = 1;
+    auto max_rank_tensor = vec2Tensor(tmp, max_rank - old_rank, ctx);
+    auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
+    nvinfer1::ITensor* const args[2] = {max_rank_tensor, in_shape_tensor};
+    return ctx->net->addConcatenation(args, 2)->getOutput(0);
+  } else { // max_rank - old_rank == 0
+    return ctx->net->addShape(*tensor)->getOutput(0);
+  }
+}
+
 bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
   auto input_dims = in->getDimensions();
   TRTORCH_CHECK(
       input_dims.nbDims <= expandedDims.nbDims,
       "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
 
   // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
-  for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
+  for (int i = expandedDims.nbDims - 1; i >= 0; --i) {
     int64_t offset = expandedDims.nbDims - 1 - i;
     int64_t dim = input_dims.nbDims - 1 - offset;
     int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
     int64_t targetSize = expandedDims.d[i];
-    if (size != targetSize) {
-      if (size != 1) {
-        TRTORCH_THROW_ERROR(
-            "The expanded size of tensor (" << targetSize << ")"
-            << " must match the existing size (" << size << ")"
-            << " at dimension " << i);
+    if (targetSize != -1) {
+      if (size != targetSize) {
+        if (size != 1) {
+          TRTORCH_THROW_ERROR(
+              "The expanded size of tensor (" << targetSize << ")"
+              << " must match the existing size (" << size << ")"
+              << " at dimension " << i);
+        }
+      }
+    } else {
+      if (dim < 0) {
+        TRTORCH_THROW_ERROR(
+            "The expanded size of the tensor (" << targetSize
+            << ") isn't allowed in a leading, non-existing dimension " << i);
+      } else {
+        // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
+        expandedDims.d[i] = input_dims.d[dim];
       }
     }
   }
@@ -41,10 +111,10 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
   if (num_expand_dims > 0) {
     nvinfer1::Dims reshape_dims;
     reshape_dims.nbDims = expandedDims.nbDims;
-    for (int64_t i = 0; i < num_expand_dims; i++) {
+    for (int i = 0; i < num_expand_dims; i++) {
       reshape_dims.d[i] = 1;
     }
-    for (int64_t i = 0; i < input_dims.nbDims; i++) {
+    for (int i = 0; i < input_dims.nbDims; i++) {
       reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
     }
     // Add a reshape layer to expand dims
@@ -60,7 +130,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
 
   // Set the stride of non singleton dimension to 1
   std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
-  for (int64_t i = 0; i < expandedDims.nbDims; i++) {
+  for (int i = 0; i < expandedDims.nbDims; i++) {
     strides_vec[i] = (in->getDimensions().d[i] != 1);
   }
 
@@ -76,6 +146,61 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
   return true;
 }
 
+bool add_expand_dynamic(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::ITensor* expandedDimsTensor) {
+  auto input_shape_tensor = ctx->net->addShape(*in)->getOutput(0);
+  auto input_rank = in->getDimensions().nbDims;
+  auto output_rank = expandedDimsTensor->getDimensions().d[0];
+  TRTORCH_CHECK(
+      input_rank <= output_rank,
+      "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
+
+  // Add a plugin to check whether expandedDimsTensor matches input_shape_tensor
+  auto expandShape_layer = create_plugin(ctx, n, input_shape_tensor, expandedDimsTensor, input_rank, output_rank, "expandShape");
+  auto _tensor = expandShape_layer->getOutput(0);
+
+  size_t max_rank = std::max(input_rank, output_rank);
+
+  // Dimensions are right-aligned
+  auto new_input_shape_tensor = concat(max_rank, input_rank, ctx, in);
+  auto new_output_shape_tensor = expandedDimsTensor;
+
+  // Add a reshape layer to expand dims
+  auto shuffle = ctx->net->addShuffle(*in);
+  shuffle->setInput(1, *new_input_shape_tensor);
+
+  // Start the slicing from beginning of tensor since this is an expand layer
+  std::vector<int64_t> start_vec(max_rank, 0);
+  nvinfer1::Dims starts_dim = util::toDims(c10::IntArrayRef(start_vec));
+
+  // Compute sizes = max(x, y)
+  auto sizes = ctx->net->addElementWise(*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0);
+  nvinfer1::Dims sizes_dim{-1, {}};
+  sizes_dim.nbDims = max_rank;
+
+  // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations:
+  // strides = min(1, sub(input_shape, 1))
+  int32_t* one_vector_tmp = new int32_t[1];
+  one_vector_tmp[0] = 1;
+  auto one_vector = vec2Tensor(one_vector_tmp, 1, ctx);
+  auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
+  auto strides = ctx->net->addElementWise(*one_vector, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
+  nvinfer1::Dims strides_dim{-1, {}};
+  strides_dim.nbDims = max_rank;
+
+  // Slice layer does the expansion in TRT. Desired output size is specified by expandedDimsTensor
+  auto slice = ctx->net->addSlice(*shuffle->getOutput(0), starts_dim, sizes_dim, strides_dim);
+  addSliceInput(starts_dim, 1, ctx, slice);
+  slice->setInput(2, *sizes);
+  slice->setInput(3, *strides);
+
+  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice->getOutput(0));
+
+  LOG_DEBUG("Expand layer output tensor shape: " << out_tensor->getDimensions());
+
+  return true;
+}
+
 auto expand_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
@@ -85,51 +210,75 @@ auto expand_registrations TRTORCH_UNUSED =
                     auto expanded_size = args[1].unwrapToIntList();
                     auto expandedDims = util::toDims(expanded_size);
                     LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-                    return add_expand(ctx, n, in, expandedDims);
+                    if (ctx->input_is_dynamic) {
+                      int expanded_size_rank = static_cast<int>(expanded_size.size());
+                      int32_t* tmp = new int32_t[expanded_size_rank];
+                      for (int i = 0; i < expanded_size_rank; i++)
+                        tmp[i] = expanded_size[i];
+                      auto expandedDimsTensor = vec2Tensor(tmp, expanded_size_rank, ctx);
+                      return add_expand_dynamic(ctx, n, in, expandedDimsTensor);
+                    } else {
+                      return add_expand(ctx, n, in, expandedDims);
+                    }
                   }})
         .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    // TODO: Currently expand supports static shapes. Need to explore if the same code can be extended
-                    // to dynamic expansion.
                     auto in = args[0].ITensor();
                     auto input_dims = in->getDimensions();
                     auto targetTensor = args[1].ITensor();
                     auto targetDims = targetTensor->getDimensions();
                     LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-                    return add_expand(ctx, n, in, targetDims);
+                    if (ctx->input_is_dynamic) {
+                      return add_expand_dynamic(ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0));
+                    } else {
+                      return add_expand(ctx, n, in, targetDims);
+                    }
                   }})
         .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto in = args[0].ITensor();
                     auto input_dims = in->getDimensions();
                     auto repeats = args[1].unwrapToIntList().vec();
+                    int repeats_rank = repeats.size();
                     TRTORCH_CHECK(
-                        static_cast<int64_t>(repeats.size()) >= input_dims.nbDims,
+                        repeats_rank >= input_dims.nbDims,
                         "Number of repeat dimensions cannot be smaller than number of input dimensions");
-                    auto num_expand_dims = repeats.size() - input_dims.nbDims;
-                    if (num_expand_dims > 0) {
-                      nvinfer1::Dims reshape_dims;
-                      reshape_dims.nbDims = repeats.size();
-                      for (size_t i = 0; i < num_expand_dims; i++) {
-                        reshape_dims.d[i] = 1;
-                      }
-                      for (int64_t i = 0; i < input_dims.nbDims; i++) {
-                        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                      }
+                    auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+                    if (ctx->input_is_dynamic) {
+                      int input_rank = input_dims.nbDims;
+                      int output_rank = repeats_rank;
+                      auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
                       // Add a reshape layer to expand dims
-                      auto reshape_layer = ctx->net->addShuffle(*in);
-                      reshape_layer->setReshapeDimensions(reshape_dims);
-                      in = reshape_layer->getOutput(0);
-                      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                      auto shuffle = ctx->net->addShuffle(*in);
+                      shuffle->setInput(1, *new_input_shape_tensor);
+                      in = shuffle->getOutput(0);
+                    } else {
+                      if (num_expand_dims > 0) {
+                        nvinfer1::Dims reshape_dims;
+                        reshape_dims.nbDims = repeats.size();
+                        for (int i = 0; i < num_expand_dims; i++) {
+                          reshape_dims.d[i] = 1;
+                        }
+                        for (int i = 0; i < input_dims.nbDims; i++) {
+                          reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                        }
+                        // Add a reshape layer to expand dims
+                        auto reshape_layer = ctx->net->addShuffle(*in);
+                        reshape_layer->setReshapeDimensions(reshape_dims);
+                        in = reshape_layer->getOutput(0);
+                        LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                      }
+                      LOG_DEBUG("Repeats: " << repeats);
                     }
 
-                    LOG_DEBUG("Repeats: " << repeats);
-
                     // Concat across all repeat axes.
                     // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-                    for (int64_t i = repeats.size() - 1; i >= 0; --i) {
+                    for (int i = repeats.size() - 1; i >= 0; --i) {
                       std::vector<nvinfer1::ITensor*> tensors_vec;
-                      for (int64_t j = 0; j < repeats[i]; j++) {
+                      for (int j = 0; j < repeats[i]; j++) {
                         tensors_vec.push_back(in);
                       }
                       auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
@@ -139,8 +288,7 @@ auto expand_registrations TRTORCH_UNUSED =
 
                     auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
 
-                    LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-
+                    LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());
                     return true;
                   }});
 

@@ -149,4 +297,4 @@ auto expand_registrations TRTORCH_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
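
A note on the shape arithmetic in add_expand_dynamic above: it builds sizes = max(input_shape, target_shape) with kMAX (which also resolves -1 entries, since max(x, -1) == x for positive x) and strides = min(1, input_shape - 1) with kSUB/kMIN, so size-1 dimensions get stride 0 and are repeated by the slice layer. Below is a host-side C++ sketch of the same arithmetic, assuming right-aligned shapes; the helper names expand_sizes and expand_strides are illustrative, not part of the commit:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of the in-network shape math, computed on the host.
// in_shape is right-aligned and padded with leading 1s to the output rank.
std::vector<int32_t> expand_sizes(std::vector<int32_t> in_shape,
                                  const std::vector<int32_t>& out_shape) {
  assert(in_shape.size() <= out_shape.size());
  in_shape.insert(in_shape.begin(), out_shape.size() - in_shape.size(), 1);
  std::vector<int32_t> sizes(out_shape.size());
  for (size_t i = 0; i < out_shape.size(); i++) {
    // sizes = max(x, y): picks the broadcast target, and resolves a -1 in
    // out_shape to the input size, since max(x, -1) == x for positive x.
    sizes[i] = std::max(in_shape[i], out_shape[i]);
  }
  return sizes;
}

std::vector<int32_t> expand_strides(std::vector<int32_t> in_shape, size_t out_rank) {
  assert(in_shape.size() <= out_rank);
  in_shape.insert(in_shape.begin(), out_rank - in_shape.size(), 1);
  std::vector<int32_t> strides(out_rank);
  for (size_t i = 0; i < out_rank; i++) {
    // (x > 1 ? 1 : 0) for positive x, written as min(1, x - 1):
    // a size-1 dim gets stride 0, so the slice repeats that element.
    strides[i] = std::min(1, in_shape[i] - 1);
  }
  return strides;
}

int main() {
  // in(3, 1), expand(3, -1, 4) -> sizes(3, 3, 4), strides(0, 1, 0)
  auto sizes = expand_sizes({3, 1}, {3, -1, 4});
  auto strides = expand_strides({3, 1}, 3);
  assert((sizes == std::vector<int32_t>{3, 3, 4}));
  assert((strides == std::vector<int32_t>{0, 1, 0}));
  return 0;
}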

core/conversion/converters/impl/plugins/BUILD

Lines changed: 5 additions & 3 deletions
@@ -10,10 +10,12 @@ config_setting(
 cc_library(
     name = "plugins",
     hdrs = [
-        "interpolate_plugin.h"
+        "interpolate_plugin.h",
+        "checkshape_plugin.h"
     ],
     srcs = [
-        "interpolate_plugin.cpp"
+        "interpolate_plugin.cpp",
+        "checkshape_plugin.cpp"
     ],
     deps = [
         "@tensorrt//:nvinfer",
@@ -37,5 +39,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
 pkg_tar(
     name = "include",
     package_dir = "core/conversion/converters/impl/plugins",
-    srcs = ["interpolate_plugin.h"],
+    srcs = ["interpolate_plugin.h", "checkshape_plugin.h"],
 )
