
Commit ca3ee6d

Merge pull request #236 from NVIDIA/dynamic_batch
fix(aten::flatten): Fixing flatten converter to handle dynamic batch
2 parents b1e59ea + aff4492 commit ca3ee6d

File tree

6 files changed: +151 -114 lines changed

core/conversion/converters/impl/expand.cpp

Lines changed: 80 additions & 83 deletions
@@ -196,94 +196,91 @@ bool add_expand_dynamic(
 
 auto expand_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
-        .pattern(
-            {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto expanded_size = args[1].unwrapToIntList();
-               auto expandedDims = util::toDims(expanded_size);
-               LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
-               if (ctx->input_is_dynamic) {
-                 at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
-                 auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
-                 return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
-               } else {
-                 return add_expand(ctx, n, in, expandedDims);
-               }
-             }})
-        .pattern(
-            {"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto targetTensor = args[1].ITensor();
-               auto targetDims = targetTensor->getDimensions();
-               LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
-               if (ctx->input_is_dynamic) {
-                 return add_expand_dynamic(
-                     ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
-               } else {
-                 return add_expand(ctx, n, in, targetDims);
-               }
-             }})
-        .pattern(
-            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensor();
-               auto input_dims = in->getDimensions();
-               auto repeats = args[1].unwrapToIntList().vec();
-               int repeats_rank = repeats.size();
-               TRTORCH_CHECK(
-                   repeats_rank >= input_dims.nbDims,
-                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
-               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+        .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto expanded_size = args[1].unwrapToIntList();
+                    auto expandedDims = util::toDims(expanded_size);
+                    LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
+                    if (ctx->input_is_dynamic) {
+                      at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
+                      auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
+                      return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
+                    } else {
+                      return add_expand(ctx, n, in, expandedDims);
+                    }
+                  }})
+        .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto targetTensor = args[1].ITensor();
+                    auto targetDims = targetTensor->getDimensions();
+                    LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
+                    if (ctx->input_is_dynamic) {
+                      return add_expand_dynamic(
+                          ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
+                    } else {
+                      return add_expand(ctx, n, in, targetDims);
+                    }
+                  }})
+        .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensor();
+                    auto input_dims = in->getDimensions();
+                    auto repeats = args[1].unwrapToIntList().vec();
+                    int repeats_rank = repeats.size();
+                    TRTORCH_CHECK(
+                        repeats_rank >= input_dims.nbDims,
+                        "Number of repeat dimensions cannot be smaller than number of input dimensions");
+                    auto num_expand_dims = repeats_rank - input_dims.nbDims;
 
-               if (ctx->input_is_dynamic) {
-                 int input_rank = input_dims.nbDims;
-                 int output_rank = repeats_rank;
-                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+                    if (ctx->input_is_dynamic) {
+                      int input_rank = input_dims.nbDims;
+                      int output_rank = repeats_rank;
+                      auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
 
-                 // Add a reshape layer to expand dims
-                 auto shuffle = ctx->net->addShuffle(*in);
-                 shuffle->setInput(1, *new_input_shape_tensor);
-                 in = shuffle->getOutput(0);
-               } else {
-                 if (num_expand_dims > 0) {
-                   nvinfer1::Dims reshape_dims;
-                   reshape_dims.nbDims = repeats.size();
-                   for (int i = 0; i < num_expand_dims; i++) {
-                     reshape_dims.d[i] = 1;
-                   }
-                   for (int i = 0; i < input_dims.nbDims; i++) {
-                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                   }
-                   // Add a reshape layer to expand dims
-                   auto reshape_layer = ctx->net->addShuffle(*in);
-                   reshape_layer->setReshapeDimensions(reshape_dims);
-                   in = reshape_layer->getOutput(0);
-                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                 }
-                 LOG_DEBUG("Repeats: " << repeats);
-               }
+                      // Add a reshape layer to expand dims
+                      auto shuffle = ctx->net->addShuffle(*in);
+                      shuffle->setInput(1, *new_input_shape_tensor);
+                      in = shuffle->getOutput(0);
+                    } else {
+                      if (num_expand_dims > 0) {
+                        nvinfer1::Dims reshape_dims;
+                        reshape_dims.nbDims = repeats.size();
+                        for (int i = 0; i < num_expand_dims; i++) {
+                          reshape_dims.d[i] = 1;
+                        }
+                        for (int i = 0; i < input_dims.nbDims; i++) {
+                          reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                        }
+                        // Add a reshape layer to expand dims
+                        auto reshape_layer = ctx->net->addShuffle(*in);
+                        reshape_layer->setReshapeDimensions(reshape_dims);
+                        in = reshape_layer->getOutput(0);
+                        LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                      }
+                      LOG_DEBUG("Repeats: " << repeats);
+                    }
 
-               // Concat across all repeat axes.
-               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-               for (int i = repeats.size() - 1; i >= 0; --i) {
-                 std::vector<nvinfer1::ITensor*> tensors_vec;
-                 for (int j = 0; j < repeats[i]; j++) {
-                   tensors_vec.push_back(in);
-                 }
-                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                 concat_layer->setAxis(i);
-                 in = concat_layer->getOutput(0);
-               }
+                    // Concat across all repeat axes.
+                    // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+                    for (int i = repeats.size() - 1; i >= 0; --i) {
+                      std::vector<nvinfer1::ITensor*> tensors_vec;
+                      for (int j = 0; j < repeats[i]; j++) {
+                        tensors_vec.push_back(in);
+                      }
+                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                      concat_layer->setAxis(i);
+                      in = concat_layer->getOutput(0);
+                    }
 
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
 
-               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-               return true;
-             }});
+                    LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+                    return true;
+                  }});
 
 } // namespace
 } // namespace impl
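The expand.cpp hunk above is a formatting-only reshuffle, but it shows how aten::repeat is lowered: the input is first reshaped so its rank matches the repeats list (prepending 1s), then the tensor is concatenated with itself repeats[i] times along each axis. A minimal standalone sketch of just the shape arithmetic this produces, assuming a static input shape (the helper names here are illustrative, not part of the converter):

// Minimal sketch: output shape of aten::repeat under the reshape-then-concat lowering above.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> repeat_output_shape(std::vector<int64_t> in_shape, const std::vector<int64_t>& repeats) {
  // Step 1: prepend 1s so the rank matches repeats.size(), as the shuffle/reshape layer does.
  while (in_shape.size() < repeats.size()) {
    in_shape.insert(in_shape.begin(), 1);
  }
  // Step 2: concatenating the tensor with itself repeats[i] times along axis i
  // multiplies that dimension by repeats[i].
  for (size_t i = 0; i < repeats.size(); ++i) {
    in_shape[i] *= repeats[i];
  }
  return in_shape;
}

int main() {
  // A {2, 3} input with repeats {3, 1, 2}: prepend a 1 -> {1, 2, 3}, then scale -> {3, 2, 6}.
  for (auto d : repeat_output_shape({2, 3}, {3, 1, 2})) {
    std::cout << d << " ";
  }
  std::cout << std::endl; // prints: 3 2 6
  return 0;
}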

core/conversion/converters/impl/shuffle.cpp

Lines changed: 25 additions & 20 deletions
@@ -11,28 +11,33 @@ namespace {
 
 static auto shuffle_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
-        .pattern({"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensorOrFreeze(ctx);
-                    auto start_dim = args[1].unwrapToInt();
-                    auto end_dim = args[2].unwrapToInt();
-                    auto in_shape = util::toVec(in->getDimensions());
-                    std::vector<int64_t> out_shape;
-                    if (ctx->input_is_dynamic) {
-                      out_shape = std::vector<int64_t>({in_shape[0], -1});
-                    } else {
-                      out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
-                    }
+        .pattern(
+            {"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto start_dim = args[1].unwrapToInt();
+               auto end_dim = args[2].unwrapToInt();
+               auto in_shape = util::toVec(in->getDimensions());
+               std::vector<int64_t> out_shape;
+               if (ctx->input_is_dynamic && in_shape[0] != -1) {
+                 out_shape = std::vector<int64_t>({in_shape[0], -1});
+               } else if (ctx->input_is_dynamic && in_shape[0] == -1) {
+                 out_shape = std::vector<int64_t>(
+                     {-1,
+                      -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
+               } else {
+                 out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
+               }
 
-                    auto shuffle = ctx->net->addShuffle(*in);
-                    TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-                    shuffle->setReshapeDimensions(util::toDims(out_shape));
-                    shuffle->setName(util::node_info(n).c_str());
+               auto shuffle = ctx->net->addShuffle(*in);
+               TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+               shuffle->setReshapeDimensions(util::toDims(out_shape));
+               shuffle->setName(util::node_info(n).c_str());
 
-                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
-                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-                    return true;
-                  }})
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+               return true;
+             }})
         .pattern({"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto in = args[0].ITensorOrFreeze(ctx);
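This is the substantive fix: when the batch dimension itself is dynamic, TensorRT reports it as -1, so the old out_shape of {in_shape[0], -1} would have had two wildcard dimensions. The new branch keeps -1 for the batch dimension and computes the flattened size of the remaining dimensions explicitly; because in_shape[0] is -1, the product of all entries is negative, and multiplying by -1 recovers the positive product of the non-batch dims. A minimal standalone sketch of that arithmetic (not the converter itself):

// Minimal sketch: reshape dims produced by the new dynamic-batch branch of aten::flatten.
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

std::vector<int64_t> flatten_out_shape_dynamic_batch(const std::vector<int64_t>& in_shape) {
  // in_shape[0] == -1 makes the full product negative; -1 * product gives the
  // flattened size of everything except the batch dimension.
  return {-1,
          -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), (int64_t)1, std::multiplies<int64_t>())};
}

int main() {
  // For an input reported as {-1, 3, 4}, this branch produces reshape dims {-1, 12}.
  for (auto d : flatten_out_shape_dynamic_batch({-1, 3, 4})) {
    std::cout << d << " ";
  }
  std::cout << std::endl; // prints: -1 12
  return 0;
}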

tests/core/conversion/converters/test_pooling.cpp

Lines changed: 1 addition & 1 deletion
@@ -402,7 +402,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
 
   auto trt_in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false);
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }

tests/core/conversion/converters/test_shuffle.cpp

Lines changed: 24 additions & 1 deletion
@@ -186,7 +186,30 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {
 
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, false);
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %2 : int = prim::Constant[value=1]()
+        %3 : Tensor = aten::flatten(%0, %1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
   auto trt = trt_results[0].reshape_as(jit_results[0]);
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
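The new test runs the same flatten graph through an engine whose batch dimension is made dynamic (dynamic_batch = true), which exercises the in_shape[0] == -1 branch added in shuffle.cpp. A minimal sketch of what the JIT reference side computes, assuming a CPU tensor for illustration (the test itself uses a CUDA tensor):

// Minimal sketch: aten::flatten(%0, 0, 1) on a {2, 3} tensor collapses dims 0..1 into a size-6 dim.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto in = torch::randint(0, 5, {2, 3});
  auto out = torch::flatten(in, /*start_dim=*/0, /*end_dim=*/1);
  std::cout << out.sizes() << std::endl; // prints: [6]
  return 0;
}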

tests/util/run_graph_engine.cpp

Lines changed: 19 additions & 8 deletions
@@ -23,19 +23,29 @@ std::vector<core::conversion::InputRange> toInputRanges(std::vector<at::Tensor>
   return std::move(a);
 }
 
-std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten) {
+std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
   std::vector<core::conversion::InputRange> a;
 
   for (auto i : ten) {
     auto opt = core::util::toVec(i.sizes());
 
-    std::vector<int64_t> min_range(opt);
-    std::vector<int64_t> max_range(opt);
+    if (dynamic_batch) {
+      std::vector<int64_t> min_range(opt);
+      std::vector<int64_t> max_range(opt);
 
-    min_range[1] = ceil(opt[1] / 2.0);
-    max_range[1] = 2 * opt[1];
+      min_range[0] = ceil(opt[0] / 2.0);
+      max_range[0] = 2 * opt[0];
 
-    a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+      a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+    } else {
+      std::vector<int64_t> min_range(opt);
+      std::vector<int64_t> max_range(opt);
+
+      min_range[1] = ceil(opt[1] / 2.0);
+      max_range[1] = 2 * opt[1];
+
+      a.push_back(core::conversion::InputRange(min_range, opt, max_range));
+    }
   }
 
   return std::move(a);
@@ -63,9 +73,10 @@ std::vector<at::Tensor> RunGraphEngine(
 std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs) {
+    std::vector<at::Tensor> inputs,
+    bool dynamic_batch) {
   LOG_DEBUG("Running TRT version");
-  auto in = toInputRangesDynamic(inputs);
+  auto in = toInputRangesDynamic(inputs, dynamic_batch);
   auto info = core::conversion::ConversionInfo(in);
   info.engine_settings.workspace_size = 1 << 20;
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
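With dynamic_batch = true the test harness now varies dimension 0 of each input between half and twice its optimum size; with the default false it keeps the previous behavior of varying dimension 1. A minimal standalone sketch (the Range struct and make_range helper are illustrative, not part of the test utilities) showing the min/opt/max shapes generated for a {2, 3} input under both settings:

// Minimal sketch: min/opt/max shapes mirroring toInputRangesDynamic above.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

struct Range { std::vector<int64_t> min, opt, max; };

Range make_range(std::vector<int64_t> opt, bool dynamic_batch) {
  std::vector<int64_t> min_range(opt);
  std::vector<int64_t> max_range(opt);
  size_t dim = dynamic_batch ? 0 : 1; // vary the batch dim, or dim 1 as before
  min_range[dim] = static_cast<int64_t>(std::ceil(opt[dim] / 2.0));
  max_range[dim] = 2 * opt[dim];
  return {min_range, opt, max_range};
}

int main() {
  auto batch = make_range({2, 3}, true);    // min {1, 3}, opt {2, 3}, max {4, 3}
  auto channel = make_range({2, 3}, false); // min {2, 2}, opt {2, 3}, max {2, 6}
  std::cout << batch.min[0] << " " << channel.max[1] << std::endl; // prints: 1 6
  return 0;
}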

tests/util/util.h

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ std::vector<at::Tensor> RunGraphEngine(
 std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs);
+    std::vector<at::Tensor> inputs,
+    bool dynamic_batch = false);
 
 // Run the forward method of a module and return results
 torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
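Because the new parameter defaults to false, existing call sites keep the previous channel-dynamic behavior without modification, and only tests that want a dynamic batch dimension opt in. A sketch of the two call forms seen in the tests above (fragment, not a standalone program):

// Existing tests: default dynamic_batch = false, so dim 1 gets the min/opt/max range.
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
// New dynamic-batch test: dim 0 (the batch dimension) gets the range instead.
auto trt_results_db = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, true);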
