@@ -235,6 +235,22 @@ def forward(self, x):
         c = torch.add(b, b)
         return c
 
+class ChannelShuffle(nn.Module):
+    def __init__(self, batchsize, num_channels, height, width, groups):
+        super(ChannelShuffle, self).__init__()
+        self.batchsize = batchsize
+        self.num_channels = num_channels
+        self.height = height
+        self.width = width
+        self.groups = groups
+
+    def forward(self, x):
+        # Split the channels into groups, swap the group and per-group axes,
+        # then flatten back to (N, C, H, W) so channels are interleaved across groups.
+        channels_per_group = self.num_channels // self.groups
+        x = x.view(self.batchsize, self.groups, channels_per_group, self.height, self.width)
+        x = torch.transpose(x, 1, 2).contiguous()
+        x = x.view(self.batchsize, -1, self.height, self.width)
+        return x
+
 
 class Tester(TestCase):
 
@@ -528,6 +544,13 @@ def test_output_linear_relu(self):
             kind_in_graph="ipex::linear_relu")
 
 
+    def test_channel_shuffle(self):
+        self._test_output(
+            ChannelShuffle(10, 16, 50, 50, 4),
+            torch.rand(10, 16, 50, 50),
+            kind_in_graph="ipex::shuffle_2d")
+
+
     def test_jit_function(self):
         # test that both trace and script work for a function
         def fn(input, weight, bias):
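
For reference, the view/transpose/view sequence in the new ChannelShuffle module (which the added test expects the JIT pass to lower to the fused ipex::shuffle_2d op) is the standard group-wise channel shuffle. The snippet below is a minimal standalone sketch, not part of this patch, that checks the same pattern against PyTorch's built-in nn.ChannelShuffle; it assumes a PyTorch release that ships torch.nn.ChannelShuffle.

import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    # Same view/transpose/view pattern as the ChannelShuffle module in the patch.
    n, c, h, w = x.shape
    x = x.view(n, groups, c // groups, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(n, -1, h, w)

x = torch.rand(10, 16, 50, 50)
# nn.ChannelShuffle(4) applies the same permutation of the channel dimension,
# so the two results should be bit-identical.
assert torch.equal(channel_shuffle(x, 4), nn.ChannelShuffle(4)(x))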