Skip to content

Commit 6f2166e

Browse files
committed
Add aten.transpose support.
Signed-off-by: Yu-Te Cheng <[email protected]>
1 parent 26d5c65 commit 6f2166e

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

core/conversion/converters/impl/shuffle.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,40 @@ static auto shuffle_registrations TRTORCH_UNUSED =
9292
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
9393
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
9494

95+
return true;
96+
}})
97+
.pattern({"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))",
98+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
99+
auto in = args[0].ITensorOrFreeze(ctx);
100+
auto in_shape = util::toVec(in->getDimensions());
101+
auto ndims = in_shape.size();
102+
auto dim0 = args[1].unwrapToInt();
103+
auto dim1 = args[2].unwrapToInt();
104+
105+
std::vector<int64_t> new_order;
106+
for (size_t i = 0; i < ndims; i++) {
107+
new_order.push_back(i);
108+
}
109+
auto tmp = dim0;
110+
new_order[dim0] = new_order[dim1];
111+
new_order[dim1] = tmp;
112+
113+
LOG_DEBUG("Shuffle to: " << util::toDims(new_order));
114+
115+
auto shuffle = ctx->net->addShuffle(*in);
116+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
117+
nvinfer1::Permutation permute;
118+
std::copy(new_order.begin(), new_order.end(), permute.order);
119+
120+
shuffle->setSecondTranspose(permute);
121+
shuffle->setName(util::node_info(n).c_str());
122+
123+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
124+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
125+
95126
return true;
96127
}});
128+
97129
} // namespace
98130
} // namespace impl
99131
} // namespace converters

tests/core/conversion/converters/test_shuffle.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,29 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
214214

215215
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
216216
}
217+
218+
TEST(Converters, ATenTransposeConvertsCorrectly) {
219+
const auto graph = R"IR(
220+
graph(%x.1 : Tensor):
221+
%2 : int = prim::Constant[value=1]()
222+
%3 : int = prim::Constant[value=3]()
223+
%4 : Tensor = aten::transpose(%x.1, %2, %3)
224+
return (%4))IR";
225+
226+
auto g = std::make_shared<torch::jit::Graph>();
227+
torch::jit::parseIR(graph, &*g);
228+
229+
auto in = at::randint(0, 5, {2, 3, 4, 5, 6}, {at::kCUDA});
230+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
231+
232+
std::cout << "Running JIT" << std::endl;
233+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
234+
235+
std::cout << "Running TRT" << std::endl;
236+
in = at::clone(in);
237+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
238+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
239+
auto trt = trt_results[0].reshape_as(jit_results[0]);
240+
241+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
242+
}

0 commit comments

Comments
 (0)