
Commit 37305c0

Merge pull request #315 from guoruoqian/expandDynamic
Support dynamic input in expand layer, expand_as layer and repeat layer
2 parents 0b6a3fd + 5b0f584 commit 37305c0
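
For reference, the ATen operators affected by this commit follow PyTorch's expand/expand_as/repeat semantics, where -1 in an expand size list means "keep this dimension's existing size". The snippet below is an illustrative libtorch sketch (not part of this commit; the helper and names are hypothetical) showing the shapes the converters are expected to reproduce.

// Illustrative libtorch sketch, not part of commit 37305c0: the ATen ops whose
// converters gain dynamic-shape support here, shown with static tensors so the
// expected output shapes are easy to verify.
#include <torch/torch.h>
#include <iostream>

static void print_shape(const char* name, const at::Tensor& t) {
  std::cout << name << ": [";
  for (int64_t i = 0; i < t.dim(); ++i) {
    std::cout << t.size(i) << (i + 1 < t.dim() ? ", " : "");
  }
  std::cout << "]\n";
}

int main() {
  auto x = torch::rand({3, 1});

  // aten::expand: -1 keeps a dimension's existing size; a new leading dimension may not be -1.
  auto a = x.expand({3, -1, 4}); // in(3, 1), expand(3, -1, 4) -> (3, 3, 4)

  // aten::expand_as: the target shape comes from another tensor.
  auto other = torch::rand({2, 3, 1});
  auto b = x.expand_as(other); // -> (2, 3, 1)

  // aten::repeat: tiles the input along each axis, adding leading axes as needed.
  auto c = x.repeat({2, 1, 4}); // -> (2, 3, 4)

  print_shape("expand   ", a);
  print_shape("expand_as", b);
  print_shape("repeat   ", c);
  return 0;
}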

3 files changed: +458 −73 lines changed

core/conversion/converters/impl/expand.cpp

Lines changed: 212 additions & 71 deletions
@@ -15,6 +15,18 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
+  if (max_rank - old_rank > 0) {
+    torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
+    auto one_tensor = tensor_to_const(ctx, thOne);
+    auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
+    nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
+    return ctx->net->addConcatenation(args, 2)->getOutput(0);
+  } else { // max_rank - old_rank == 0
+    return ctx->net->addShape(*tensor)->getOutput(0);
+  }
+}
+
 bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
   auto input_dims = in->getDimensions();
   TRTORCH_CHECK(
@@ -27,12 +39,26 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
     int64_t dim = input_dims.nbDims - 1 - offset;
     int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
     int64_t targetSize = expandedDims.d[i];
-    if (size != targetSize) {
-      if (size != 1) {
+    // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
+    if (targetSize != -1) {
+      if (size != targetSize) {
+        if (size != 1) {
+          TRTORCH_THROW_ERROR(
+              "The expanded size of tensor (" << targetSize << ")"
+              << " must match the existing size (" << size << ")"
+              << " at dimension " << i);
+        }
+      }
+    } else {
+      // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
+      // not [-1, 3, 4].
+      if (dim < 0) {
         TRTORCH_THROW_ERROR(
-            "The expanded size of tensor (" << targetSize << ")"
-            << " must match the existing size (" << size << ")"
-            << " at dimension " << i);
+            "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
+            << i);
+      } else {
+        // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
+        expandedDims.d[i] = input_dims.d[dim];
       }
     }
   }
@@ -76,77 +102,192 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
   return true;
 }
 
+bool add_expand_dynamic(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in,
+    nvinfer1::ITensor* expandedDimsTensor,
+    nvinfer1::Dims expandedDims,
+    bool is_expand_layer) {
+  auto input_dims = in->getDimensions();
+  auto input_rank = in->getDimensions().nbDims;
+  auto output_rank = expandedDims.nbDims;
+  TRTORCH_CHECK(
+      input_rank <= output_rank,
+      "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
+
+  /* TODO: When the inputs are dynamic, some dimensions of the inputs are indeterminate before setBindingDimensions. For
+     these indeterminate dimensions, we don't validate the expansion. Eg: For an input of [3, -1], we omit the
+     validation of the second dimension. Need to explore a better way to validate the expansion.
+  */
+  // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
+  for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
+    int64_t offset = expandedDims.nbDims - 1 - i;
+    int64_t dim = input_dims.nbDims - 1 - offset;
+    int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
+    int64_t targetSize = expandedDims.d[i];
+    // Passing -1 as the size for a dimension means not changing the size of that dimension in expand layer.
+    if (targetSize != -1) {
+      if (size != targetSize) {
+        // if size == -1, we can't validate the expansion before setBindingDimensions.
+        if (!(size == -1 || size == 1)) {
+          TRTORCH_THROW_ERROR(
+              "The expanded size of tensor (" << targetSize << ")"
+              << " must match the existing size (" << size << ")"
+              << " at dimension " << i);
+        }
+      }
+    } else {
+      // In dynamic expand layer, for the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be
+      // expanded to [3, -1, 4] but not [-1, 3, 4].
+      if (is_expand_layer && dim < 0) {
+        TRTORCH_THROW_ERROR(
+            "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
+            << i);
+      }
+    }
+  }
+
+  size_t max_rank = std::max(input_rank, output_rank);
+
+  // Dimensions are right alignment. Eg: an input of [3, 1] and max_rank = 4, the result of concat is [1, 1, 3, 1]
+  auto new_input_shape_tensor = concat(max_rank, input_rank, ctx, in);
+  auto new_output_shape_tensor = expandedDimsTensor;
+
+  // Add a reshape layer to expand dims
+  auto shuffle = ctx->net->addShuffle(*in);
+  shuffle->setInput(1, *new_input_shape_tensor);
+
+  // Start the slicing from beginning of tensor since this is an expand layer
+  std::vector<int64_t> start_vec(max_rank, 0);
+  nvinfer1::Dims starts_dim = util::toDims(c10::IntArrayRef(start_vec));
+  at::Tensor thStart = torch::tensor(util::toVec(starts_dim), torch::kInt32);
+  auto starts = tensor_to_const(ctx, thStart);
+
+  // compute sizes = max(x,y).
+  auto sizes =
+      ctx->net->addElementWise(*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX)
+          ->getOutput(0);
+  nvinfer1::Dims sizes_dim{-1, {}};
+  sizes_dim.nbDims = max_rank;
+
+  // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
+  // min(1, sub(input_shape, 1))
+  torch::Tensor thOne = torch::tensor({1}, torch::kInt32);
+  auto one_tensor = tensor_to_const(ctx, thOne);
+  auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_tensor, nvinfer1::ElementWiseOperation::kSUB)
+                       ->getOutput(0);
+  auto strides = ctx->net->addElementWise(*one_tensor, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
+  nvinfer1::Dims strides_dim{-1, {}};
+  strides_dim.nbDims = max_rank;
+
+  // Slice layer does the expansion in TRT. Desired output size is specified by sizes input at index 2.
+  auto slice = ctx->net->addSlice(*shuffle->getOutput(0), starts_dim, sizes_dim, strides_dim);
+  slice->setInput(1, *starts);
+  slice->setInput(2, *sizes);
+  slice->setInput(3, *strides);
+
+  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice->getOutput(0));
+
+  LOG_DEBUG("Expand layer output tensor shape: " << out_tensor->getDimensions());
+
+  return true;
+}
+
 auto expand_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
-        .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto expanded_size = args[1].unwrapToIntList();
-                    auto expandedDims = util::toDims(expanded_size);
-                    LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-                    return add_expand(ctx, n, in, expandedDims);
-                  }})
-        .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    // TODO: Currently expand supports static shapes. Need to explore if the same code can be extended
-                    // to dynamic expansion.
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto targetTensor = args[1].ITensor();
-                    auto targetDims = targetTensor->getDimensions();
-                    LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-                    return add_expand(ctx, n, in, targetDims);
-                  }})
-        .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto repeats = args[1].unwrapToIntList().vec();
-                    TRTORCH_CHECK(
-                        static_cast<int64_t>(repeats.size()) >= input_dims.nbDims,
-                        "Number of repeat dimensions cannot be smaller than number of input dimensions");
-                    auto num_expand_dims = repeats.size() - input_dims.nbDims;
-                    if (num_expand_dims > 0) {
-                      nvinfer1::Dims reshape_dims;
-                      reshape_dims.nbDims = repeats.size();
-                      for (size_t i = 0; i < num_expand_dims; i++) {
-                        reshape_dims.d[i] = 1;
-                      }
-                      for (int64_t i = 0; i < input_dims.nbDims; i++) {
-                        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                      }
-                      // Add a reshape layer to expand dims
-                      auto reshape_layer = ctx->net->addShuffle(*in);
-                      reshape_layer->setReshapeDimensions(reshape_dims);
-                      in = reshape_layer->getOutput(0);
-                      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                    }
-
-                    LOG_DEBUG("Repeats: " << repeats);
-
-                    // Concat across all repeat axes.
-                    // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-                    for (int64_t i = repeats.size() - 1; i >= 0; --i) {
-                      std::vector<nvinfer1::ITensor*> tensors_vec;
-                      for (int64_t j = 0; j < repeats[i]; j++) {
-                        tensors_vec.push_back(in);
-                      }
-                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                      concat_layer->setAxis(i);
-                      in = concat_layer->getOutput(0);
-                    }
-
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-
-                    LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-
-                    return true;
-                  }});
+        .pattern(
+            {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto expanded_size = args[1].unwrapToIntList();
+               auto expandedDims = util::toDims(expanded_size);
+               LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
+               if (ctx->input_is_dynamic) {
+                 at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
+                 auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
+                 return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
+               } else {
+                 return add_expand(ctx, n, in, expandedDims);
+               }
+             }})
+        .pattern(
+            {"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto targetTensor = args[1].ITensor();
+               auto targetDims = targetTensor->getDimensions();
+               LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
+               if (ctx->input_is_dynamic) {
+                 return add_expand_dynamic(
+                     ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
+               } else {
+                 return add_expand(ctx, n, in, targetDims);
+               }
+             }})
+        .pattern(
+            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto repeats = args[1].unwrapToIntList().vec();
+               int repeats_rank = repeats.size();
+               TRTORCH_CHECK(
+                   repeats_rank >= input_dims.nbDims,
+                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
+               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+               if (ctx->input_is_dynamic) {
+                 int input_rank = input_dims.nbDims;
+                 int output_rank = repeats_rank;
+                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+                 // Add a reshape layer to expand dims
+                 auto shuffle = ctx->net->addShuffle(*in);
+                 shuffle->setInput(1, *new_input_shape_tensor);
+                 in = shuffle->getOutput(0);
+               } else {
+                 if (num_expand_dims > 0) {
+                   nvinfer1::Dims reshape_dims;
+                   reshape_dims.nbDims = repeats.size();
+                   for (int i = 0; i < num_expand_dims; i++) {
+                     reshape_dims.d[i] = 1;
+                   }
+                   for (int i = 0; i < input_dims.nbDims; i++) {
+                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                   }
+                   // Add a reshape layer to expand dims
+                   auto reshape_layer = ctx->net->addShuffle(*in);
+                   reshape_layer->setReshapeDimensions(reshape_dims);
+                   in = reshape_layer->getOutput(0);
+                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                 }
+                 LOG_DEBUG("Repeats: " << repeats);
+               }
+
+               // Concat across all repeat axes.
+               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+               for (int i = repeats.size() - 1; i >= 0; --i) {
+                 std::vector<nvinfer1::ITensor*> tensors_vec;
+                 for (int j = 0; j < repeats[i]; j++) {
+                   tensors_vec.push_back(in);
+                 }
+                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                 concat_layer->setAxis(i);
+                 in = concat_layer->getOutput(0);
+               }
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+               return true;
+             }});
 
 } // namespace
 } // namespace impl
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch
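
The dynamic path above never materializes concrete dimensions at conversion time; it builds the slice parameters as shape tensors: the output sizes are the elementwise max(input_shape, target_shape) and the strides are min(1, input_shape - 1), so size-1 dimensions get stride 0 and are broadcast by the slice layer. The following standalone sketch (plain C++, illustrative only, not part of the commit) traces that arithmetic for the in(3, 1) -> expand(3, -1, 4) example mentioned in the code comments.

// Illustrative sketch, not part of commit 37305c0: the shape arithmetic that
// add_expand_dynamic encodes with IElementWiseLayer/ISliceLayer, computed here
// on plain vectors for the in(3, 1), expand(3, -1, 4) example.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // concat() right-aligns the input shape against the target rank: [3, 1] -> [1, 3, 1].
  std::vector<int64_t> in_shape{1, 3, 1};
  // Size list passed to aten::expand; -1 means "keep the existing size".
  std::vector<int64_t> target{3, -1, 4};

  std::vector<int64_t> sizes(in_shape.size());
  std::vector<int64_t> strides(in_shape.size());
  for (size_t i = 0; i < in_shape.size(); ++i) {
    // Slice output size: max(x, y) also resolves a -1 target back to the input size.
    sizes[i] = std::max(in_shape[i], target[i]);
    // Slice stride: min(1, x - 1) is 0 for size-1 (broadcast) dims and 1 otherwise.
    strides[i] = std::min<int64_t>(1, in_shape[i] - 1);
  }

  for (size_t i = 0; i < sizes.size(); ++i) {
    std::cout << "dim " << i << ": size=" << sizes[i] << " stride=" << strides[i] << "\n";
  }
  // Prints sizes 3, 3, 4 with strides 0, 1, 0; a zero stride makes the slice layer
  // reread the same element, which is how the size-1 dimensions are broadcast.
  return 0;
}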

core/util/BUILD

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ cc_library(
 })
 )
 
-
 load("@rules_pkg//:pkg.bzl", "pkg_tar")
 
 pkg_tar(
