
Commit 03e161c

Enable shuffle operator

1 parent 363dc39 commit 03e161c

File tree

6 files changed: +133 -2 lines


torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 16 additions & 0 deletions
@@ -6,6 +6,8 @@
 #include <ATen/NamedTensorUtils.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Logging.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/record_function.h>

 #include <limits>

@@ -2021,5 +2023,19 @@ at::Tensor AtenIpexCPUDev::dil_index_select(
   return at::Tensor();
 }

+at::Tensor AtenIpexCPUDev::dil_shuffle(const at::Tensor & self, at::IntArrayRef view_shape, int64_t dim0, int64_t dim1) {
+  DEBUG("AtenIpexCPUDev::dil_shuffle\n");
+  RECORD_FUNCTION("AtenIpexCPUDev::dil_shuffle", std::vector<c10::IValue>(), -1);
+  // NOTE: We do NOT add sanity checks here because PyTorch has no shuffle operator. This dil operator exists only
+  // for fusion, and the fusion logic performs the sanity checks. Some models use view + transpose + view to
+  // implement shuffle semantics, so IPEX fuses those operators into a single shuffle.
+  dil::tensor&& x = dbl::comm::try_gen_dil_tensor(self);
+  dil::tensor y;
+  auto group_dim = dim0 < dim1 ? dim0 : dim1;
+  auto groups = view_shape[group_dim];
+  dil::channel_shuffle_forward::compute(std::move(x), y, groups, group_dim);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
+}
+
 } // namespace cpu
 } // namespace torch_ipex
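For context, the view + transpose + contiguous + view sequence that dil_shuffle collapses is the standard PyTorch channel-shuffle idiom. A minimal sketch of that idiom (the tensor sizes and groups value below are illustrative, not taken from this commit):

import torch

def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    n, c, h, w = x.shape
    x = x.view(n, groups, c // groups, h, w)   # view_shape = [n, groups, c // groups, h, w]
    x = x.transpose(1, 2).contiguous()         # trans_dim0 = 1, trans_dim1 = 2
    return x.view(n, -1, h, w)                 # flattern_shape = [n, -1, h, w]

x = torch.randn(2, 8, 4, 4)
print(channel_shuffle(x, groups=2).shape)      # torch.Size([2, 8, 4, 4])

With trans_dim0 = 1 and trans_dim1 = 2, dil_shuffle derives group_dim = min(1, 2) = 1 and groups = view_shape[1], i.e. the same groups the first view introduced.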

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ class AtenIpexCPUDev {
   static at::Tensor dil_view(const at::Tensor & self, at::IntArrayRef size);
   static at::Tensor dil_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
   static at::Tensor dil__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
+  static at::Tensor dil_shuffle(const at::Tensor & self, at::IntArrayRef view_shape, int64_t dim0, int64_t dim1);
 };

 } // namespace cpu

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 3 additions & 0 deletions
@@ -308,6 +308,9 @@ void FusionPass(std::shared_ptr<Graph> &graph) {
   // Fuse conv with eltwise operator
   graph_rewrite::FuseConvolutionWithEltwise(graph);

+  // Fuse operators as shuffle
+  graph_rewrite::FuseShuffle(graph);
+
   // Pattern based fusion was lack of alias analysis
   // ??? It may either be too conservative or too aggressive ???
   // getSubgraphRewriter().runOnGraph(graph);

torch_ipex/csrc/jit/graph_rewrite.cpp

Lines changed: 90 additions & 0 deletions
@@ -52,6 +52,96 @@ std::unordered_map<std::string, c10::IValue> getConvParams(
   return calc_values;
 }

+void FuseShuffle(std::shared_ptr<Graph>& graph) {
+  std::string shuffle = R"(
+      graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
+        %r = aten::view(%input, %view_shape)
+        %r = aten::transpose(%r, %trans_dim0, %trans_dim1)
+        %r = aten::contiguous(%r, %mem_format)
+        %r = aten::view(%r, %flattern_shape)
+        return (%r) )";
+
+  std::string shuffle_2d_fusion = R"(
+      graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
+        %r = ipex::shuffle_2d(%input, %view_shape, %trans_dim0, %trans_dim1)
+        return (%r) )";
+
+  auto filter_shuffle_2d_fusion = [] (
+      const Match& match,
+      const std::unordered_map<std::string, Value*>& vmap) {
+    const auto& match_vmap = match.values_map;
+    auto input_ = getIValue("input", match_vmap, vmap).value();
+    if (!(input_.isTensor())) {
+      return false;
+    }
+    auto view_shape_ = getIValue("view_shape", match_vmap, vmap).value();
+    if (!(view_shape_.isIntList())) {
+      return false;
+    }
+    auto trans_dim0_ = getIValue("trans_dim0", match_vmap, vmap).value();
+    if (!(trans_dim0_.isInt())) {
+      return false;
+    }
+    auto trans_dim1_ = getIValue("trans_dim1", match_vmap, vmap).value();
+    if (!(trans_dim1_.isInt())) {
+      return false;
+    }
+    auto flattern_shape_ = getIValue("flattern_shape", match_vmap, vmap).value();
+    if (!(flattern_shape_.isIntList())) {
+      return false;
+    }
+
+    auto trans_dim0_val = trans_dim0_.toInt();
+    auto trans_dim1_val = trans_dim1_.toInt();
+    auto dim0_val = trans_dim0_val < trans_dim1_val ? trans_dim0_val : trans_dim1_val;
+    auto dim1_val = trans_dim0_val > trans_dim1_val ? trans_dim0_val : trans_dim1_val;
+    // Bail out if the transpose is not between the two group dims,
+    // e.g. [n, c1, c2, h, w] => [n, c2, c1, h, w]
+    if ((dim1_val - dim0_val) != 1) {
+      return false;
+    }
+
+    auto input_val = input_.toTensor();
+    auto view_shape_val = view_shape_.toIntVector();
+    auto flattern_shape_val = flattern_shape_.toIntVector();
+    // The view must split exactly one dim, e.g. [n, c, h, w] => [n, groups, c // groups, h, w]
+    if ((input_val.ndimension() - view_shape_val.size()) != -1) {
+      return false;
+    }
+
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim0_val >= 0);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim1_val >= 0);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim0_val + 1 < input_val.ndimension());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim1_val + 1 < input_val.ndimension());
+    if (view_shape_val[dim0_val] * view_shape_val[dim1_val] != input_val.size(dim0_val)) {
+      return false;
+    }
+
+    if (flattern_shape_val.size() != input_val.ndimension()) {
+      return false;
+    }
+
+    for (int i = 0; i < flattern_shape_val.size(); i++) {
+      if (flattern_shape_val[i] != input_val.size(i)) {
+        // [n, c, h, w] => view [n, groups, c // groups, h, w] => transpose [n, c // groups, groups, h, w]
+        //              => view [n, -1, h, w]
+        // or
+        //              => view [n, c, h, w]
+        if ((flattern_shape_val[i] != -1) || (i != dim0_val)) {
+          return false;
+        }
+      }
+    }
+
+    return true;
+  };
+
+  SubgraphRewriter rewriter_shuffle_2d;
+  rewriter_shuffle_2d.RegisterRewritePattern(
+      shuffle,
+      shuffle_2d_fusion);
+  rewriter_shuffle_2d.runOnGraph(graph, filter_shuffle_2d_fusion);
+}
+
 void FuseConvolutionWithEltwise(std::shared_ptr<Graph>& graph) {
   std::string conv2d_swish_fusion = R"(
     graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
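The filter's shape checks are easier to follow outside the C++. A Python paraphrase of filter_shuffle_2d_fusion (an illustrative sketch only; plain lists stand in for the matched IValues, and the debug-only asserts are omitted):

def is_shuffle_2d(input_shape, view_shape, trans_dim0, trans_dim1, flattern_shape):
    dim0, dim1 = sorted((trans_dim0, trans_dim1))
    if dim1 - dim0 != 1:                          # transpose must swap two adjacent dims (the group split)
        return False
    if len(view_shape) != len(input_shape) + 1:   # view must split one dim: [n,c,h,w] -> [n,g,c//g,h,w]
        return False
    if view_shape[dim0] * view_shape[dim1] != input_shape[dim0]:
        return False                              # the split must multiply back to the original dim
    if len(flattern_shape) != len(input_shape):   # final view must restore the input rank
        return False
    for i, s in enumerate(flattern_shape):
        if s != input_shape[i] and not (s == -1 and i == dim0):
            return False                          # only the split dim may differ, and only as -1
    return True

assert is_shuffle_2d([2, 8, 4, 4], [2, 2, 4, 4, 4], 1, 2, [2, -1, 4, 4])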

torch_ipex/csrc/jit/graph_rewrite.h

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ c10::optional<IValue> getIValue(
     const std::unordered_map<std::string, Value*>& vmap);
 void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
 void FuseConvolutionWithEltwise(std::shared_ptr<Graph>& graph);
+void FuseShuffle(std::shared_ptr<Graph>& graph);

 } // namespace graph_rewrite_helper
 } // namespace jit

torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp

Lines changed: 22 additions & 2 deletions
@@ -4,8 +4,8 @@
 #include <torch/csrc/jit/runtime/custom_operator.h>

 #include "torch_ipex/csrc/utils.h"
-#include "cpu/FusionOPs.h"
-
+#include "torch_ipex/csrc/cpu/FusionOPs.h"
+#include "torch_ipex/csrc/cpu/DevOPs.h"

 namespace torch {
 namespace jit {
@@ -24,6 +24,26 @@ at::Tensor toOptionalTensor(const IValue& v) {
 using namespace torch_ipex::cpu;

 RegisterOperators op({
+    Operator(
+      "ipex::shuffle_2d(Tensor input, int[5] view_shape, int trans_dim0, int trans_dim1) -> Tensor",
+      [] (const Node* node) ->Operation {
+        if (torch_ipex::check_auto_dnnl()) {
+          return [] (Stack& stack) {
+            auto result = AtenIpexCPUDev::dil_shuffle(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toIntVector(),
+                (std::move(peek(stack, 2, 4))).toInt(),
+                (std::move(peek(stack, 3, 4))).toInt());
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        } else {
+          TORCH_CHECK(false, "The PyTorch native path does not support shuffle fusion for the 2d case");
+        }
+      },
+      aliasAnalysisFromSchema()
+    ),
     Operator(
       "ipex::conv2d_relu(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
       [] (const Node* node) ->Operation {
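With the operator registered, a scripted model containing the idiom should end up with an ipex::shuffle_2d node once the fusion pass runs. A hypothetical smoke test (the package name intel_pytorch_extension and ipex.DEVICE follow the conventions of IPEX releases from this era; they are assumptions here, not part of the commit):

import torch
import intel_pytorch_extension as ipex  # assumed package name for this IPEX generation

class Shuffle(torch.nn.Module):
    def forward(self, x):
        n, c, h, w = x.shape
        y = x.view(n, 2, c // 2, h, w).transpose(1, 2).contiguous()
        return y.view(n, -1, h, w)

with torch.no_grad():
    # assumes the auto-DNNL path is enabled, so check_auto_dnnl() is true at dispatch time
    model = torch.jit.script(Shuffle().eval()).to(ipex.DEVICE)  # ipex.DEVICE is an assumption
    x = torch.randn(2, 8, 4, 4).to(ipex.DEVICE)
    model(x)                   # warm-up run so the JIT applies FusionPass
    print(model.graph_for(x))  # expect ipex::shuffle_2d in place of view/transpose/view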
