Skip to content

Commit ccf8d2e

Browse files
authored
Merge pull request #986 from NVIDIA/unbind
feat (core//conversion) : Add converter for torch.unbind
2 parents 6b9872d + 91288ae commit ccf8d2e

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,32 @@ namespace converters {
1515
namespace impl {
1616
namespace {
1717

18-
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
18+
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
1919
auto in = args[0].ITensor();
20-
auto axis = args[2].unwrapToInt();
21-
auto inDimSize = in->getDimensions().d[axis];
22-
auto numOutputs = 1, numRemainder = 0;
20+
auto numOutputs = 1, numRemainder = 0, axis = 0;
2321
std::vector<int64_t> sizes;
2422

25-
if (split_list) {
26-
sizes = args[1].unwrapToIntList().vec();
27-
numOutputs = sizes.size();
23+
if (unbind) {
24+
axis = args[1].unwrapToInt();
25+
numOutputs = in->getDimensions().d[axis];
26+
sizes.insert(sizes.end(), numOutputs, 1);
2827
} else {
29-
auto split_size = args[1].unwrapToInt();
30-
numOutputs = inDimSize / split_size;
31-
numRemainder = inDimSize % split_size;
32-
for (int64_t i = 0; i < numOutputs; i++) {
33-
sizes.push_back(split_size);
34-
}
35-
if (numRemainder) {
36-
numOutputs += 1;
37-
sizes.push_back(numRemainder);
28+
axis = args[2].unwrapToInt();
29+
auto inDimSize = in->getDimensions().d[axis];
30+
if (split_list) {
31+
sizes = args[1].unwrapToIntList().vec();
32+
numOutputs = sizes.size();
33+
} else {
34+
auto split_size = args[1].unwrapToInt();
35+
numOutputs = inDimSize / split_size;
36+
numRemainder = inDimSize % split_size;
37+
for (int64_t i = 0; i < numOutputs; i++) {
38+
sizes.push_back(split_size);
39+
}
40+
if (numRemainder) {
41+
numOutputs += 1;
42+
sizes.push_back(numRemainder);
43+
}
3844
}
3945
}
4046

@@ -340,19 +346,25 @@ auto select_registrations TORCHTRT_UNUSED =
340346
}})
341347
.pattern({"aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])",
342348
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
343-
add_split(ctx, n, args, true);
349+
add_split(ctx, n, args, true, false);
344350
LOG_DEBUG("Converted split op into a list of IValues");
345351
return true;
346352
}})
347353
.pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
348354
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
349-
add_split(ctx, n, args, false);
355+
add_split(ctx, n, args, false, false);
350356
LOG_DEBUG("Converted split op into a list of IValues");
351357
return true;
352358
}})
353359
.pattern({"aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])",
354360
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355-
add_split(ctx, n, args, true);
361+
add_split(ctx, n, args, true, false);
362+
LOG_DEBUG("Converted split op into a list of IValues");
363+
return true;
364+
}})
365+
.pattern({"aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])",
366+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
367+
add_split(ctx, n, args, false, true);
356368
LOG_DEBUG("Converted split op into a list of IValues");
357369
return true;
358370
}})

tests/core/conversion/converters/test_select.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,3 +567,30 @@ TEST(Converters, ATenIndexTensorConvertsCorrectly) {
567567
ASSERT_TRUE(
568568
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
569569
}
570+
571+
TEST(Converters, ATenUnbindConvertsCorrectly) {
572+
const auto graph = R"IR(
573+
graph(%x.1 : Tensor):
574+
%2 : int = prim::Constant[value=0]()
575+
%3 : Tensor[] = aten::unbind(%x.1, %2)
576+
%o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
577+
return (%o1.1, %o2.1))IR";
578+
579+
auto g = std::make_shared<torch::jit::Graph>();
580+
581+
torch::jit::parseIR(graph, g.get());
582+
583+
auto in = at::randint(1, 10, {2, 3, 4, 4}, {at::kCUDA});
584+
585+
auto jit_in = at::clone(in);
586+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
587+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
588+
589+
auto trt_in = at::clone(in);
590+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
591+
592+
for (size_t i = 0; i < jit_results.size(); i++) {
593+
auto trt = trt_results[i].reshape(jit_results[i].sizes());
594+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
595+
}
596+
}

0 commit comments

Comments
 (0)