
Commit aff4492

refactor(aten::flatten): apply linting

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent 8193055 commit aff4492

File tree

4 files changed: +106 additions, -107 deletions

core/conversion/converters/impl/expand.cpp

Lines changed: 80 additions & 83 deletions
@@ -196,94 +196,91 @@ bool add_expand_dynamic(

 auto expand_registrations TRTORCH_UNUSED =
 RegisterNodeConversionPatterns()
-.pattern(
-{"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
-[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-auto in = args[0].ITensor();
-auto input_dims = in->getDimensions();
-auto expanded_size = args[1].unwrapToIntList();
-auto expandedDims = util::toDims(expanded_size);
-LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-if (ctx->input_is_dynamic) {
-at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
-auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
-return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
-} else {
-return add_expand(ctx, n, in, expandedDims);
-}
-}})
-.pattern(
-{"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
-[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-auto in = args[0].ITensor();
-auto input_dims = in->getDimensions();
-auto targetTensor = args[1].ITensor();
-auto targetDims = targetTensor->getDimensions();
-LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-if (ctx->input_is_dynamic) {
-return add_expand_dynamic(
-ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
-} else {
-return add_expand(ctx, n, in, targetDims);
-}
-}})
-.pattern(
-{"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
-[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-auto in = args[0].ITensor();
-auto input_dims = in->getDimensions();
-auto repeats = args[1].unwrapToIntList().vec();
-int repeats_rank = repeats.size();
-TRTORCH_CHECK(
-repeats_rank >= input_dims.nbDims,
-"Number of repeat dimensions cannot be smaller than number of input dimensions");
-auto num_expand_dims = repeats_rank - input_dims.nbDims;
+.pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
+[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+auto in = args[0].ITensor();
+auto input_dims = in->getDimensions();
+auto expanded_size = args[1].unwrapToIntList();
+auto expandedDims = util::toDims(expanded_size);
+LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
+if (ctx->input_is_dynamic) {
+at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
+auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
+return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
+} else {
+return add_expand(ctx, n, in, expandedDims);
+}
+}})
+.pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+auto in = args[0].ITensor();
+auto input_dims = in->getDimensions();
+auto targetTensor = args[1].ITensor();
+auto targetDims = targetTensor->getDimensions();
+LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
+if (ctx->input_is_dynamic) {
+return add_expand_dynamic(
+ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
+} else {
+return add_expand(ctx, n, in, targetDims);
+}
+}})
+.pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
+[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+auto in = args[0].ITensor();
+auto input_dims = in->getDimensions();
+auto repeats = args[1].unwrapToIntList().vec();
+int repeats_rank = repeats.size();
+TRTORCH_CHECK(
+repeats_rank >= input_dims.nbDims,
+"Number of repeat dimensions cannot be smaller than number of input dimensions");
+auto num_expand_dims = repeats_rank - input_dims.nbDims;

-if (ctx->input_is_dynamic) {
-int input_rank = input_dims.nbDims;
-int output_rank = repeats_rank;
-auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+if (ctx->input_is_dynamic) {
+int input_rank = input_dims.nbDims;
+int output_rank = repeats_rank;
+auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);

-// Add a reshape layer to expand dims
-auto shuffle = ctx->net->addShuffle(*in);
-shuffle->setInput(1, *new_input_shape_tensor);
-in = shuffle->getOutput(0);
-} else {
-if (num_expand_dims > 0) {
-nvinfer1::Dims reshape_dims;
-reshape_dims.nbDims = repeats.size();
-for (int i = 0; i < num_expand_dims; i++) {
-reshape_dims.d[i] = 1;
-}
-for (int i = 0; i < input_dims.nbDims; i++) {
-reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-}
-// Add a reshape layer to expand dims
-auto reshape_layer = ctx->net->addShuffle(*in);
-reshape_layer->setReshapeDimensions(reshape_dims);
-in = reshape_layer->getOutput(0);
-LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-}
-LOG_DEBUG("Repeats: " << repeats);
-}
+// Add a reshape layer to expand dims
+auto shuffle = ctx->net->addShuffle(*in);
+shuffle->setInput(1, *new_input_shape_tensor);
+in = shuffle->getOutput(0);
+} else {
+if (num_expand_dims > 0) {
+nvinfer1::Dims reshape_dims;
+reshape_dims.nbDims = repeats.size();
+for (int i = 0; i < num_expand_dims; i++) {
+reshape_dims.d[i] = 1;
+}
+for (int i = 0; i < input_dims.nbDims; i++) {
+reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+}
+// Add a reshape layer to expand dims
+auto reshape_layer = ctx->net->addShuffle(*in);
+reshape_layer->setReshapeDimensions(reshape_dims);
+in = reshape_layer->getOutput(0);
+LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+}
+LOG_DEBUG("Repeats: " << repeats);
+}

-// Concat across all repeat axes.
-// TODO: Implementation might not be performant. Explore other strategies to improve performance.
-for (int i = repeats.size() - 1; i >= 0; --i) {
-std::vector<nvinfer1::ITensor*> tensors_vec;
-for (int j = 0; j < repeats[i]; j++) {
-tensors_vec.push_back(in);
-}
-auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-concat_layer->setAxis(i);
-in = concat_layer->getOutput(0);
-}
+// Concat across all repeat axes.
+// TODO: Implementation might not be performant. Explore other strategies to improve performance.
+for (int i = repeats.size() - 1; i >= 0; --i) {
+std::vector<nvinfer1::ITensor*> tensors_vec;
+for (int j = 0; j < repeats[i]; j++) {
+tensors_vec.push_back(in);
+}
+auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+concat_layer->setAxis(i);
+in = concat_layer->getOutput(0);
+}

-auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

-LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-return true;
-}});
+LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+return true;
+}});

 } // namespace
 } // namespace impl
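Context for the aten::repeat body above (which this commit only reformats): the converter first pads the input shape with leading 1s so its rank matches the repeats list, then builds the repeated tensor by concatenating repeats[i] copies of the input along each axis with TensorRT's IConcatenationLayer. A minimal standalone C++ sketch of that concatenate-per-axis idea on a plain 1-D sequence; it is purely illustrative and not part of the TRTorch API.

#include <cstdio>
#include <vector>

// Repeat a 1-D sequence `repeats` times by concatenation, the same strategy
// the converter applies once per axis with IConcatenationLayer.
std::vector<int> repeat_by_concat(const std::vector<int>& in, int repeats) {
  std::vector<int> out;
  out.reserve(in.size() * repeats);
  for (int r = 0; r < repeats; ++r) {
    out.insert(out.end(), in.begin(), in.end()); // append one full copy
  }
  return out;
}

int main() {
  std::vector<int> x = {1, 2, 3};
  auto y = repeat_by_concat(x, 2); // {1, 2, 3, 1, 2, 3}, matching repeat semantics on one axis
  for (int v : y) {
    std::printf("%d ", v);
  }
  std::printf("\n");
  return 0;
}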

core/conversion/converters/impl/shuffle.cpp

Lines changed: 25 additions & 22 deletions
@@ -11,30 +11,33 @@ namespace {

 static auto shuffle_registrations TRTORCH_UNUSED =
 RegisterNodeConversionPatterns()
-.pattern({"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
-[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-auto in = args[0].ITensorOrFreeze(ctx);
-auto start_dim = args[1].unwrapToInt();
-auto end_dim = args[2].unwrapToInt();
-auto in_shape = util::toVec(in->getDimensions());
-std::vector<int64_t> out_shape;
-if (ctx->input_is_dynamic && in_shape[0] != -1) {
-out_shape = std::vector<int64_t>({in_shape[0], -1});
-} else if (ctx->input_is_dynamic && in_shape[0] == -1) {
-out_shape = std::vector<int64_t>({-1, -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
-} else {
-out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
-}
+.pattern(
+{"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
+[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+auto in = args[0].ITensorOrFreeze(ctx);
+auto start_dim = args[1].unwrapToInt();
+auto end_dim = args[2].unwrapToInt();
+auto in_shape = util::toVec(in->getDimensions());
+std::vector<int64_t> out_shape;
+if (ctx->input_is_dynamic && in_shape[0] != -1) {
+out_shape = std::vector<int64_t>({in_shape[0], -1});
+} else if (ctx->input_is_dynamic && in_shape[0] == -1) {
+out_shape = std::vector<int64_t>(
+{-1,
+-1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
+} else {
+out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
+}

-auto shuffle = ctx->net->addShuffle(*in);
-TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-shuffle->setReshapeDimensions(util::toDims(out_shape));
-shuffle->setName(util::node_info(n).c_str());
+auto shuffle = ctx->net->addShuffle(*in);
+TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+shuffle->setReshapeDimensions(util::toDims(out_shape));
+shuffle->setName(util::node_info(n).c_str());

-auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
-LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-return true;
-}})
+auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+return true;
+}})
 .pattern({"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
 auto in = args[0].ITensorOrFreeze(ctx);
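For readers of the flatten converter above (again, only reformatted by this commit): when the input is dynamic and the batch dimension is known, the output shape collapses to {batch, -1}; when the batch dimension is itself dynamic, it becomes {-1, -1 * product(in_shape)} so the batch stays a wildcard while the second extent is the per-example element count; otherwise the static shape comes from torch::flatten. Below is a minimal sketch of that branch logic using plain std::vector<int64_t> values, under the simplifying assumption of a full flatten (start_dim = 0, end_dim = -1); -1 is TensorRT's marker for a dimension to be inferred.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Sketch of the out_shape selection in the flatten converter, with plain
// vectors standing in for ITensor dimensions. The real converter handles
// arbitrary start/end dims via torch::flatten.
std::vector<int64_t> flatten_out_shape(const std::vector<int64_t>& in_shape, bool input_is_dynamic) {
  int64_t prod =
      std::accumulate(in_shape.begin(), in_shape.end(), int64_t{1}, std::multiplies<int64_t>());
  if (input_is_dynamic && in_shape[0] != -1) {
    return {in_shape[0], -1}; // known batch dim stays, the rest is inferred
  } else if (input_is_dynamic && in_shape[0] == -1) {
    // in_shape[0] == -1 makes prod negative; multiplying by -1 recovers the
    // per-example element count while the batch dim remains a wildcard.
    return {-1, -1 * prod};
  }
  return {prod}; // static input: everything collapses into one dimension
}

int main() {
  auto s1 = flatten_out_shape({8, 3, 32, 32}, false); // {24576}
  auto s2 = flatten_out_shape({8, 3, 32, 32}, true);  // {8, -1}
  auto s3 = flatten_out_shape({-1, 3, 32, 32}, true); // {-1, 3072}
  std::printf("%lld %lld %lld\n", (long long)s1[0], (long long)s2[1], (long long)s3[1]);
  return 0;
}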

tests/core/conversion/converters/test_shuffle.cpp

Lines changed: 0 additions & 1 deletion
@@ -192,7 +192,6 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {
 ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }

-
 TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
 const auto graph = R"IR(
 graph(%0 : Tensor):

tests/util/util.h

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ std::vector<at::Tensor> RunGraphEngineDynamic(
 std::shared_ptr<torch::jit::Graph>& g,
 core::conversion::GraphParams& named_params,
 std::vector<at::Tensor> inputs,
-bool dynamic_batch=false);
+bool dynamic_batch = false);

 // Run the forward method of a module and return results
 torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
