Skip to content

Commit 5e1b842

Browse files
authored
Merge pull request #188 from peri044/power_layer
feat(//core/converters): Add conversion support for torch.narrow()
2 parents e854c75 + 682e2f0 commit 5e1b842

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,77 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
4646

4747
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
4848

49+
return true;
50+
}
51+
}).pattern({
52+
"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)",
53+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
54+
auto in = args[0].ITensor();
55+
auto axis = args[1].unwrapToInt();
56+
auto start = (int32_t) args[2].unwrapToInt();
57+
auto length = (int32_t) args[3].unwrapToInt();
58+
59+
// index to access needs to be an at::Tensor
60+
at::Tensor indices = torch::arange(start, start + length, 1).to(torch::kI32);
61+
auto weights = Weights(ctx, indices);
62+
63+
// IConstantLayer to convert indices from Weights to ITensor
64+
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
65+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
66+
auto const_out = const_layer->getOutput(0);
67+
68+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
69+
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
70+
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
71+
auto gather_out = gather_layer->getOutput(0);
72+
73+
// IShuffleLayer removes redundant dimensions
74+
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
75+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
76+
shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions()));
77+
shuffle_layer->setName(util::node_info(n).c_str());
78+
auto shuffle_out = shuffle_layer->getOutput(0);
79+
80+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
81+
82+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
83+
84+
return true;
85+
}
86+
}).pattern({
87+
"aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)",
88+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
89+
auto in = args[0].ITensor();
90+
auto axis = args[1].unwrapToInt();
91+
torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32);
92+
int32_t startIdx = start.item().to<int32_t>();
93+
auto length = (int32_t) args[3].unwrapToInt();
94+
95+
// index to access needs to be an at::Tensor
96+
at::Tensor indices = torch::arange(startIdx, startIdx + length, 1).to(torch::kI32);
97+
auto weights = Weights(ctx, indices);
98+
99+
// IConstantLayer to convert indices from Weights to ITensor
100+
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
101+
TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
102+
auto const_out = const_layer->getOutput(0);
103+
104+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
105+
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
106+
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
107+
auto gather_out = gather_layer->getOutput(0);
108+
109+
// IShuffleLayer removes redundant dimensions
110+
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
111+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
112+
shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions()));
113+
shuffle_layer->setName(util::node_info(n).c_str());
114+
auto shuffle_out = shuffle_layer->getOutput(0);
115+
116+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
117+
118+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
119+
49120
return true;
50121
}
51122
});

tests/core/converters/test_select.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,32 @@ TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
5555

5656
auto trt = trt_results[0].reshape(jit_results[0].sizes());
5757

58+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
59+
}
60+
61+
TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
62+
const auto graph = R"IR(
63+
graph(%x.1 : Tensor):
64+
%2 : int = prim::Constant[value=2]()
65+
%3 : int = prim::Constant[value=0]()
66+
%4 : Tensor = aten::narrow(%x.1, %3, %3, %2)
67+
return (%4))IR";
68+
69+
auto g = std::make_shared<torch::jit::Graph>();
70+
71+
torch::jit::parseIR(graph, &*g);
72+
73+
auto in = at::randint(1, 10, {3, 2, 2, 4}, {at::kCUDA});
74+
75+
auto jit_in = at::clone(in);
76+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
77+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
78+
79+
auto trt_in = at::clone(in);
80+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
81+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
82+
83+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
84+
5885
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
5986
}

0 commit comments

Comments
 (0)