Skip to content

Commit f0cdcc6

Browse files
committed
Fuse view/tranpose/view as shuffle
1 parent 03e161c commit f0cdcc6

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

tests/cpu/test_jit.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,22 @@ def forward(self, x):
235235
c = torch.add(b, b)
236236
return c
237237

238+
class ChannelShuffle(nn.Module):
239+
def __init__(self, batchsize, num_channels, height, width, groups):
240+
super(ChannelShuffle, self).__init__()
241+
self.batchsize = batchsize
242+
self.num_channels = num_channels
243+
self.height = height
244+
self.width = width
245+
self.groups = groups
246+
247+
def forward(self, x):
248+
channels_per_group = self.num_channels // self.groups
249+
x = x.view(self.batchsize, self.groups, channels_per_group, self.height, self.width)
250+
x = torch.transpose(x, 1, 2).contiguous()
251+
x = x.view(self.batchsize, -1, self.height, self.width)
252+
return x
253+
238254

239255
class Tester(TestCase):
240256

@@ -528,6 +544,13 @@ def test_output_linear_relu(self):
528544
kind_in_graph="ipex::linear_relu")
529545

530546

547+
def test_channel_shuffle(self):
548+
self._test_output(
549+
ChannelShuffle(10, 16, 50, 50, 4),
550+
torch.rand(10, 16, 50, 50),
551+
kind_in_graph="ipex::shuffle_2d")
552+
553+
531554
def test_jit_function(self):
532555
# test hool trace and script can works for function
533556
def fn(input, weight, bias):

torch_ipex/csrc/jit/graph_rewrite.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ std::unordered_map<std::string, c10::IValue> getConvParams(
5555
void FuseShuffle(std::shared_ptr<Graph>& graph) {
5656
std::string shuffle = R"(
5757
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
58-
%r = aten::view(%input, %view_shape)
59-
%r = aten::transpose(%r, %trans_dim0, %trans_dim1)
60-
%r = aten::contiguous(%r, %mem_format)
61-
%r = aten::view(%r, %flattern_shape)
62-
return (%r) )";
58+
%r1 = aten::view(%input, %view_shape)
59+
%r2 = aten::transpose(%r1, %trans_dim0, %trans_dim1)
60+
%r3 = aten::contiguous(%r2, %mem_format)
61+
%r4 = aten::view(%r3, %flattern_shape)
62+
return (%r4) )";
6363

6464
std::string shuffle_2d_fusion = R"(
6565
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):

0 commit comments

Comments
 (0)