Commit f531b32

Modify code to conform with style guidelines
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 895542c commit f531b32

4 files changed: +415 −380 lines changed

core/conversion/converters/impl/expand.cpp

Lines changed: 120 additions & 110 deletions
@@ -46,30 +46,30 @@ nvinfer1::ILayer* create_plugin(
 void addSliceInput(nvinfer1::Dims& dims, int idx, ConversionCtx* ctx, nvinfer1::ISliceLayer* slice) {
   int32_t rank = static_cast<int32_t>(dims.nbDims);
   int32_t* tmp = new int32_t[rank];
-  for(int i=0;i<rank;i++)
+  for (int i = 0; i < rank; i++)
     tmp[i] = dims.d[i];
   const nvinfer1::Dims d{1, {rank}};
   const nvinfer1::Weights w{nvinfer1::DataType::kINT32, tmp, rank};
   auto t = ctx->net->addConstant(d, w)->getOutput(0);
   slice->setInput(idx, *t);
 }
 
-nvinfer1::ITensor* vec2Tensor(int32_t *dim, int rank, ConversionCtx* ctx){
+nvinfer1::ITensor* vec2Tensor(int32_t* dim, int rank, ConversionCtx* ctx) {
   const nvinfer1::Dims d{1, {static_cast<int32_t>(rank)}};
   const nvinfer1::Weights w{nvinfer1::DataType::kINT32, dim, rank};
   return ctx->net->addConstant(d, w)->getOutput(0);
 }
 
-nvinfer1::ITensor * concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor*tensor){
-  if(max_rank - old_rank > 0){
+nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
+  if (max_rank - old_rank > 0) {
     int32_t* tmp = new int32_t[max_rank - old_rank];
-    for(int i=0;i<(max_rank - old_rank);i++)
+    for (int i = 0; i < (max_rank - old_rank); i++)
       tmp[i] = 1;
     auto max_rank_tensor = vec2Tensor(tmp, max_rank - old_rank, ctx);
     auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
     nvinfer1::ITensor* const args[2] = {max_rank_tensor, in_shape_tensor};
     return ctx->net->addConcatenation(args, 2)->getOutput(0);
-  }else{ // max_rank - old_rank == 0
+  } else { // max_rank - old_rank == 0
     return ctx->net->addShape(*tensor)->getOutput(0);
   }
 }
@@ -86,7 +86,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
     int64_t dim = input_dims.nbDims - 1 - offset;
     int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
     int64_t targetSize = expandedDims.d[i];
-    if(targetSize != -1){
+    if (targetSize != -1) {
       if (size != targetSize) {
         if (size != 1) {
           TRTORCH_THROW_ERROR(
@@ -95,12 +95,12 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
               << " at dimension " << i);
         }
       }
-    }else{
-      if(dim < 0){
-        TRTORCH_THROW_ERROR("The expanded size of the tensor (" << \
-          targetSize << ") isn't allowed in a leading, non-existing dimension " << \
-          i);
-      }else{
+    } else {
+      if (dim < 0) {
+        TRTORCH_THROW_ERROR(
+            "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
+            << i);
+      } else {
         // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
         expandedDims.d[i] = input_dims.d[dim];
       }
@@ -146,18 +146,23 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
   return true;
 }
 
-bool add_expand_dynamic(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::ITensor* expandedDimsTensor){
-  auto input_shape_tensor = ctx->net->addShape(*in)->getOutput(0);
+bool add_expand_dynamic(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in,
+    nvinfer1::ITensor* expandedDimsTensor) {
+  auto input_shape_tensor = ctx->net->addShape(*in)->getOutput(0);
   auto input_rank = in->getDimensions().nbDims;
   auto output_rank = expandedDimsTensor->getDimensions().d[0];
   TRTORCH_CHECK(
       input_rank <= output_rank,
       "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");
-
+
   // add a plugin to check expandedDimsTensor whether match input_shape_tensor
-  auto expandShape_layer = create_plugin(ctx, n, input_shape_tensor, expandedDimsTensor, input_rank, output_rank, "expandShape");
+  auto expandShape_layer =
+      create_plugin(ctx, n, input_shape_tensor, expandedDimsTensor, input_rank, output_rank, "expandShape");
   auto _tensor = expandShape_layer->getOutput(0);
-
+
   size_t max_rank = std::max(input_rank, output_rank);
 
   // Dimensions are right alignment
@@ -174,16 +179,19 @@ bool add_expand_dynamic(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1:
   nvinfer1::Dims starts_dim = util::toDims(c10::IntArrayRef(start_vec));
 
   // compute sizes = max(x,y).
-  auto sizes = ctx->net->addElementWise(*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0);
+  auto sizes =
+      ctx->net->addElementWise(*new_input_shape_tensor, *new_output_shape_tensor, nvinfer1::ElementWiseOperation::kMAX)
+          ->getOutput(0);
   nvinfer1::Dims sizes_dim{-1, {}};
   sizes_dim.nbDims = max_rank;
-
+
   // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
   // min(1, sub(input_shape, 1))
   int32_t* one_vector_tmp = new int32_t[1];
   one_vector_tmp[0] = 1;
   auto one_vector = vec2Tensor(one_vector_tmp, 1, ctx);
-  auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
+  auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB)
+                       ->getOutput(0);
   auto strides = ctx->net->addElementWise(*one_vector, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
   nvinfer1::Dims strides_dim{-1, {}};
   strides_dim.nbDims = max_rank;
@@ -194,7 +202,7 @@ bool add_expand_dynamic(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1:
   slice->setInput(2, *sizes);
   slice->setInput(3, *strides);
 
-  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice->getOutput(0));
+  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice->getOutput(0));
 
   LOG_DEBUG("Expand layer output tensor shape: " << out_tensor->getDimensions());
 
@@ -203,94 +211,96 @@ bool add_expand_dynamic(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1:
 
 auto expand_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
-        .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto expanded_size = args[1].unwrapToIntList();
-                    auto expandedDims = util::toDims(expanded_size);
-                    LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-                    if(ctx->input_is_dynamic){
-                      int expanded_size_rank = static_cast<int>(expanded_size.size());
-                      int32_t* tmp = new int32_t[expanded_size_rank];
-                      for(int i=0;i<expanded_size_rank;i++)
-                        tmp[i] = expanded_size[i];
-                      auto expandedDimsTensor = vec2Tensor(tmp, expanded_size_rank, ctx);
-                      return add_expand_dynamic(ctx, n, in, expandedDimsTensor);
-                    }else{
-                      return add_expand(ctx, n, in, expandedDims);
-                    }
-                  }})
-        .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto targetTensor = args[1].ITensor();
-                    auto targetDims = targetTensor->getDimensions();
-                    LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-                    if(ctx->input_is_dynamic){
-                      return add_expand_dynamic(ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0));
-                    }else{
-                      return add_expand(ctx, n, in, targetDims);
-                    }
-
-                  }})
-        .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
-                    auto input_dims = in->getDimensions();
-                    auto repeats = args[1].unwrapToIntList().vec();
-                    int repeats_rank = repeats.size();
-                    TRTORCH_CHECK(
-                        repeats_rank >= input_dims.nbDims,
-                        "Number of repeat dimensions cannot be smaller than number of input dimensions");
-                    auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-                    if(ctx->input_is_dynamic){
-                      int input_rank = input_dims.nbDims;
-                      int output_rank= repeats_rank;
-                      auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-                      // Add a reshape layer to expand dims
-                      auto shuffle = ctx->net->addShuffle(*in);
-                      shuffle->setInput(1, *new_input_shape_tensor);
-                      in = shuffle->getOutput(0);
-                    }else{
-                      if (num_expand_dims > 0) {
-                        nvinfer1::Dims reshape_dims;
-                        reshape_dims.nbDims = repeats.size();
-                        for (int i = 0; i < num_expand_dims; i++) {
-                          reshape_dims.d[i] = 1;
-                        }
-                        for (int i = 0; i < input_dims.nbDims; i++) {
-                          reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                        }
-                        // Add a reshape layer to expand dims
-                        auto reshape_layer = ctx->net->addShuffle(*in);
-                        reshape_layer->setReshapeDimensions(reshape_dims);
-                        in = reshape_layer->getOutput(0);
-                        LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                      }
-                      LOG_DEBUG("Repeats: " << repeats);
-                    }
-
-                    // Concat across all repeat axes.
-                    // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-                    for (int i = repeats.size() - 1; i >= 0; --i) {
-                      std::vector<nvinfer1::ITensor*> tensors_vec;
-                      for (int j = 0; j < repeats[i]; j++) {
-                        tensors_vec.push_back(in);
-                      }
-                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                      concat_layer->setAxis(i);
-                      in = concat_layer->getOutput(0);
-                    }
-
-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-
-                    LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());
-                    return true;
-                  }});
+        .pattern(
+            {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto expanded_size = args[1].unwrapToIntList();
+               auto expandedDims = util::toDims(expanded_size);
+               LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
+               if (ctx->input_is_dynamic) {
+                 int expanded_size_rank = static_cast<int>(expanded_size.size());
+                 int32_t* tmp = new int32_t[expanded_size_rank];
+                 for (int i = 0; i < expanded_size_rank; i++)
+                   tmp[i] = expanded_size[i];
+                 auto expandedDimsTensor = vec2Tensor(tmp, expanded_size_rank, ctx);
+                 return add_expand_dynamic(ctx, n, in, expandedDimsTensor);
+               } else {
+                 return add_expand(ctx, n, in, expandedDims);
+               }
+             }})
+        .pattern(
+            {"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto targetTensor = args[1].ITensor();
+               auto targetDims = targetTensor->getDimensions();
+               LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
+               if (ctx->input_is_dynamic) {
+                 return add_expand_dynamic(ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0));
+               } else {
+                 return add_expand(ctx, n, in, targetDims);
+               }
+             }})
+        .pattern(
+            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensor();
+               auto input_dims = in->getDimensions();
+               auto repeats = args[1].unwrapToIntList().vec();
+               int repeats_rank = repeats.size();
+               TRTORCH_CHECK(
+                   repeats_rank >= input_dims.nbDims,
+                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
+               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+               if (ctx->input_is_dynamic) {
+                 int input_rank = input_dims.nbDims;
+                 int output_rank = repeats_rank;
+                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+                 // Add a reshape layer to expand dims
+                 auto shuffle = ctx->net->addShuffle(*in);
+                 shuffle->setInput(1, *new_input_shape_tensor);
+                 in = shuffle->getOutput(0);
+               } else {
+                 if (num_expand_dims > 0) {
+                   nvinfer1::Dims reshape_dims;
+                   reshape_dims.nbDims = repeats.size();
+                   for (int i = 0; i < num_expand_dims; i++) {
+                     reshape_dims.d[i] = 1;
+                   }
+                   for (int i = 0; i < input_dims.nbDims; i++) {
+                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                   }
+                   // Add a reshape layer to expand dims
+                   auto reshape_layer = ctx->net->addShuffle(*in);
+                   reshape_layer->setReshapeDimensions(reshape_dims);
+                   in = reshape_layer->getOutput(0);
+                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                 }
+                 LOG_DEBUG("Repeats: " << repeats);
+               }
+
+               // Concat across all repeat axes.
+               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+               for (int i = repeats.size() - 1; i >= 0; --i) {
+                 std::vector<nvinfer1::ITensor*> tensors_vec;
+                 for (int j = 0; j < repeats[i]; j++) {
+                   tensors_vec.push_back(in);
+                 }
+                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                 concat_layer->setAxis(i);
+                 in = concat_layer->getOutput(0);
+               }
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+               LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());
+               return true;
+             }});
 
 } // namespace
 } // namespace impl
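
Note on the stride rule reformatted in this diff: the converter's own comment says it computes (x > 1 ? 1 : 0) for each right-aligned input dimension x as min(1, x - 1), so that size-1 axes get slice stride 0 (their single value is broadcast) and larger axes get stride 1. The following is a minimal standalone sketch of that arithmetic in plain C++, without TensorRT; the example shape and variable names are illustrative only and are not part of the commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical right-aligned input shape; in the converter this lives in a
  // shape tensor and the same arithmetic is built from elementwise layers.
  std::vector<int32_t> input_shape = {3, 1, 4};
  for (size_t i = 0; i < input_shape.size(); i++) {
    // min(1, x - 1): 0 when x == 1 (broadcast this axis), 1 when x > 1.
    // Assumes x >= 1, as the converter's comment does.
    int32_t stride = std::min<int32_t>(1, input_shape[i] - 1);
    std::cout << "dim " << i << ": size " << input_shape[i] << " -> stride " << stride << "\n";
  }
  return 0;
}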
