@@ -52,6 +52,96 @@ std::unordered_map<std::string, c10::IValue> getConvParams(
5252 return calc_values;
5353}
5454
55+ void FuseShuffle (std::shared_ptr<Graph>& graph) {
56+ std::string shuffle = R"(
57+ graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
58+ %r = aten::view(%input, %view_shape)
59+ %r = aten::transpose(%r, %trans_dim0, %trans_dim1)
60+ %r = aten::contiguous(%r, %mem_format)
61+ %r = aten::view(%r, %flattern_shape)
62+ return (%r) )" ;
63+
64+ std::string shuffle_2d_fusion = R"(
65+ graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
66+ %r = ipex::shuffle_2d(%input, %view_shape, %trans_dim0, %trans_dim1)
67+ return (%r) )" ;
68+
69+ auto filter_shuffle_2d_fusion = [] (
70+ const Match& match,
71+ const std::unordered_map<std::string, Value*>& vmap) {
72+ const auto & match_vmap = match.values_map ;
73+ auto input_ = getIValue (" input" , match_vmap, vmap).value ();
74+ if (!(input_.isTensor ())) {
75+ return false ;
76+ }
77+ auto view_shape_ = getIValue (" view_shape" , match_vmap, vmap).value ();
78+ if (!(view_shape_.isIntList ())) {
79+ return false ;
80+ }
81+ auto trans_dim0_ = getIValue (" trans_dim0" , match_vmap, vmap).value ();
82+ if (!(trans_dim0_.isInt ())) {
83+ return false ;
84+ }
85+ auto trans_dim1_ = getIValue (" trans_dim1" , match_vmap, vmap).value ();
86+ if (!(trans_dim1_.isInt ())) {
87+ return false ;
88+ }
89+ auto flattern_shape_ = getIValue (" flattern_shape" , match_vmap, vmap).value ();
90+ if (!(flattern_shape_.isInt ())) {
91+ return false ;
92+ }
93+
94+ auto trans_dim0_val = trans_dim0_.toInt ();
95+ auto trans_dim1_val = trans_dim1_.toInt ();
96+ auto dim0_val = trans_dim0_val < trans_dim1_val ? trans_dim0_val : trans_dim1_val;
97+ auto dim1_val = trans_dim0_val > trans_dim1_val ? trans_dim0_val : trans_dim1_val;
98+ // If the tranpose if not for groups. ex. [n, c1, c2, h, w] => [n, c2, c1, h, w]
99+ if ((dim1_val - dim0_val) != 1 ) {
100+ return false ;
101+ }
102+
103+ auto input_val = input_.toTensor ();
104+ auto view_shape_val = view_shape_.toIntVector ();
105+ auto flattern_shape_val = flattern_shape_.toIntVector ();
106+ // ex. [n, c, h, w] => [n, groups, c // groups, h, w]
107+ if ((input_val.ndimension () - view_shape_val.size ()) != -1 ) {
108+ return false ;
109+ }
110+
111+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (dim0_val >= 0 );
112+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (dim1_val >= 0 );
113+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (dim0_val + 1 < input_val.ndimension ());
114+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (dim1_val + 1 < input_val.ndimension ());
115+ if (view_shape_val[dim0_val] * view_shape_val[dim1_val] != input_val.size (dim0_val)) {
116+ return false ;
117+ }
118+
119+ if (flattern_shape_val.size () != input_val.ndimension ()) {
120+ return false ;
121+ }
122+
123+ for (int i = 0 ; i < flattern_shape_val.size (); i++) {
124+ if (flattern_shape_val[i] != input_val.size (i)) {
125+ // [n, c, h, w] => view [n, groups, c // groups, h, w] => tranpose [n, c // groups, groups, h, w]
126+ // => view [n, -1, h, w]
127+ // or
128+ // view [n, c, h, w]
129+ if ((flattern_shape_val[i] != -1 ) || (i != dim0_val)) {
130+ return false ;
131+ }
132+ }
133+ }
134+
135+ return true ;
136+ };
137+
138+ SubgraphRewriter rewriter_shuffle_2d;
139+ rewriter_shuffle_2d.RegisterRewritePattern (
140+ shuffle,
141+ shuffle_2d_fusion);
142+ rewriter_shuffle_2d.runOnGraph (graph);
143+ }
144+
55145void FuseConvolutionWithEltwise (std::shared_ptr<Graph>& graph) {
56146 std::string conv2d_swish_fusion = R"(
57147 graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
0 commit comments