
Commit 0aea728

Merge pull request #421 from guoruoqian/pixel_shuffle
Support pixel_shuffle converter
2 parents c302041 + b784638 commit 0aea728
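
For context, aten::pixel_shuffle rearranges elements from the channel dimension into spatial blocks: an input of shape (N, C * r^2, H, W) becomes (N, C, H * r, W * r) for an upscale factor r, and inputs only need at least 3 dimensions (any extra leading dims act as batch dims). A minimal libtorch sketch of that shape contract, with illustrative sizes that are not taken from this commit:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Illustrative input: one batch dim, 9 channels (= 1 * 3 * 3), height 4, width 5.
  auto x = torch::randn({1, 9, 4, 5});
  // pixel_shuffle with upscale_factor 3 moves each group of 9 channels into a 3x3 spatial block.
  auto y = at::pixel_shuffle(x, /*upscale_factor=*/3);
  std::cout << y.sizes() << std::endl;  // prints [1, 1, 12, 15]
  return 0;
}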

2 files changed: 150 additions, 0 deletions

core/conversion/converters/impl/shuffle.cpp

Lines changed: 75 additions & 0 deletions
@@ -125,6 +125,81 @@ static auto shuffle_registrations TRTORCH_UNUSED =
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

               return true;
             }})
        .pattern({"aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto self = args[0].ITensorOrFreeze(ctx);
                    auto in_shape = util::toVec(self->getDimensions());
                    int64_t irank = in_shape.size();
                    TRTORCH_CHECK(
                        irank >= 3,
                        "pixel_shuffle expects input to have at least 3 dimensions, but got input with "
                            << irank << " dimension(s)");
                    int64_t upscale_factor = args[1].unwrapToInt();
                    TRTORCH_CHECK(
                        upscale_factor > 0,
                        "pixel_shuffle expects a positive upscale_factor, but got " << upscale_factor);
                    int64_t upscale_factor_squared = upscale_factor * upscale_factor;

                    const auto NUM_NON_BATCH_DIMS = 3;
                    const auto self_sizes_batch_end = in_shape.end() - NUM_NON_BATCH_DIMS;

                    int64_t ic = in_shape[irank - 3];
                    int64_t ih = in_shape[irank - 2];
                    int64_t iw = in_shape[irank - 1];

                    TRTORCH_CHECK(
                        ic % upscale_factor_squared == 0,
                        "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
                            << "upscale_factor, but input.size(-3)=" << ic << " is not divisible by "
                            << upscale_factor_squared);

                    int64_t oc = ic / upscale_factor_squared;
                    int64_t oh = ih * upscale_factor;
                    int64_t ow = iw * upscale_factor;

                    // First, reshape to split the channels dim from c into 3 separate dims: (oc,
                    // upscale_factor, upscale_factor). This allows shuffling to be done next by
                    // permuting dims.
                    std::vector<int64_t> added_dims_shape(in_shape.begin(), self_sizes_batch_end);
                    added_dims_shape.insert(added_dims_shape.end(), {oc, upscale_factor, upscale_factor, ih, iw});
                    auto view_layer = ctx->net->addShuffle(*self);
                    TRTORCH_CHECK(view_layer, "Unable to create shuffle layer from node: " << *n);
                    view_layer->setReshapeDimensions(util::toDims(added_dims_shape));
                    int64_t view_rank = added_dims_shape.size();

                    // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
                    auto permutation_layer = ctx->net->addShuffle(*view_layer->getOutput(0));
                    TRTORCH_CHECK(permutation_layer, "Unable to create shuffle layer from node: " << *n);
                    // std::iota is used to maintain the batch dims within the permutation.
                    // Eg: if added_dims_shape is {n1, n2, c, r, r, h, w}, then the new_order is {view_rank-7,
                    // view_rank-6, view_rank-5, view_rank-2, view_rank-4, view_rank-1, view_rank-3}
                    std::vector<int64_t> new_order(in_shape.begin(), self_sizes_batch_end);
                    std::iota(new_order.begin(), new_order.end(), 0);
                    new_order.insert(
                        new_order.end(),
                        {view_rank - 5 /* oc */,
                         view_rank - 2 /* ih */,
                         view_rank - 4 /* 1st upscale_factor */,
                         view_rank - 1 /* iw */,
                         view_rank - 3 /* 2nd upscale_factor */});
                    nvinfer1::Permutation permute;
                    std::copy(new_order.begin(), new_order.end(), permute.order);
                    permutation_layer->setSecondTranspose(permute);

                    // Finally, upscale by collapsing (ih, upscale_factor) -> a single dim (oh)
                    // and (iw, upscale_factor) -> a single dim (ow).
                    std::vector<int64_t> final_shape(in_shape.begin(), self_sizes_batch_end);
                    final_shape.insert(final_shape.end(), {oc, oh, ow});
                    auto last_view_layer = ctx->net->addShuffle(*permutation_layer->getOutput(0));
                    TRTORCH_CHECK(last_view_layer, "Unable to create shuffle layer from node: " << *n);
                    last_view_layer->setReshapeDimensions(util::toDims(final_shape));
                    last_view_layer->setName(util::node_info(n).c_str());

                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], last_view_layer->getOutput(0));
                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

                    return true;
                  }});
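
The converter above lowers pixel_shuffle onto three TensorRT shuffle layers: a reshape that splits the channel dim into (oc, r, r), a transpose that interleaves the two r dims with height and width, and a final reshape that collapses (ih, r) and (iw, r) into (oh, ow). A hedged libtorch sketch of that same decomposition, checked against at::pixel_shuffle; the concrete sizes here are illustrative and not part of the commit:

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t r = 3;
  auto x = torch::randn({1, 9, 4, 5});  // (N, oc * r * r, ih, iw) with oc = 1, ih = 4, iw = 5

  // 1) Split channels: (N, oc*r*r, ih, iw) -> (N, oc, r, r, ih, iw)   [first shuffle layer's reshape]
  auto split = x.view({1, 1, r, r, 4, 5});
  // 2) Interleave each r dim with the spatial dim it upscales:        [setSecondTranspose permutation]
  //    (N, oc, r, r, ih, iw) -> (N, oc, ih, r, iw, r)
  auto perm = split.permute({0, 1, 4, 2, 5, 3});
  // 3) Collapse (ih, r) -> oh and (iw, r) -> ow:                      [last shuffle layer's reshape]
  auto manual = perm.reshape({1, 1, 4 * r, 5 * r});

  std::cout << torch::allclose(manual, at::pixel_shuffle(x, r)) << std::endl;  // prints 1
  return 0;
}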

tests/core/conversion/converters/test_shuffle.cpp

Lines changed: 75 additions & 0 deletions
@@ -266,3 +266,78 @@ TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenPixelShuffleConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : Tensor = aten::pixel_shuffle(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(0, 5, {1, 9, 4, 5}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});

  std::cout << "Running JIT" << std::endl;
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  std::cout << "Running TRT" << std::endl;
  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  // auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenPixelShuffle3DConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : Tensor = aten::pixel_shuffle(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(0, 5, {9, 4, 5}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});

  std::cout << "Running JIT" << std::endl;
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  std::cout << "Running TRT" << std::endl;
  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  // auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenPixelShuffle5DConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : Tensor = aten::pixel_shuffle(%x.1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(0, 5, {2, 3, 9, 4, 5}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});

  std::cout << "Running JIT" << std::endl;
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  std::cout << "Running TRT" << std::endl;
  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  // auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
