Skip to content

Commit 49ed0e8

Browse files
authored
[QNN-EP] Fuse ChannelShuffle pattern (microsoft#24904)
### Description Fuse transposed channel shuffle pattern into QNN op -- ONNX does not have native ChannelShuffle op. ### Motivation and Context Improves performance on QNN EP.
1 parent 3ca8a49 commit 49ed0e8

File tree

4 files changed

+454
-1
lines changed

4 files changed

+454
-1
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h"
5+
6+
#include <gsl/gsl>
7+
#include <optional>
8+
#include <utility>
9+
#include <string>
10+
#include <array>
11+
#include <memory>
12+
#include <unordered_map>
13+
#include <vector>
14+
15+
#include "core/common/inlined_containers.h"
16+
#include "core/providers/qnn/builder/qnn_utils.h"
17+
#include "core/providers/qnn/builder/op_builder_factory.h"
18+
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
19+
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
20+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
21+
22+
namespace onnxruntime {
23+
namespace qnn {
24+
namespace {
25+
26+
constexpr char kAttrTransposePerm[] = "perm";
27+
constexpr char kOpChannelShuffle[] = "ChannelShuffle";
28+
constexpr char kOpTranspose[] = "Transpose";
29+
constexpr char kOpReshape[] = "Reshape";
30+
31+
using MapNodeToNodeUnit = std::unordered_map<const Node*, const NodeUnit*>;
32+
using MapNodeUnitToGroup = std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>;
33+
34+
std::optional<std::vector<int64_t>> GetTransposePerm(const NodeUnit& transpose) {
35+
if (transpose.OpType() != kOpTranspose) {
36+
return std::nullopt;
37+
}
38+
NodeAttrHelper helper(transpose.GetNode());
39+
return helper.Get(kAttrTransposePerm, std::vector<int64_t>());
40+
}
41+
42+
// Computes the inverse permutation: for every position src, result[perm[src]] = src.
// Assumes `perm` is a valid permutation of [0, perm.size()).
std::vector<int64_t> InvertTransposePerm(gsl::span<const int64_t> perm) {
  std::vector<int64_t> inverse(perm.size());
  for (size_t src = 0; src < perm.size(); ++src) {
    const size_t dst = gsl::narrow_cast<size_t>(perm[src]);
    inverse[dst] = gsl::narrow_cast<int64_t>(src);
  }
  return inverse;
}
51+
52+
bool IsCancelingTransposePermPair(
53+
std::optional<gsl::span<const int64_t>> perm1,
54+
std::optional<gsl::span<const int64_t>> perm2) {
55+
if (!perm1.has_value() || !perm2.has_value()) {
56+
return false;
57+
}
58+
if (perm1->size() != perm2->size()) {
59+
return false;
60+
}
61+
std::vector<int64_t> perm1_inverted_vector = InvertTransposePerm(*perm1);
62+
auto perm1_inverted = gsl::make_span<const int64_t>(
63+
perm1_inverted_vector.data(), perm1_inverted_vector.size());
64+
if (perm1_inverted != perm2.value()) {
65+
return false;
66+
}
67+
return true;
68+
}
69+
70+
/// @brief Match pattern: Transpose -> ChannelShuffle (Reshape -> Transpose -> Reshape) -> Transpose
71+
/// E.g.,: T(perm=[0, 2, 1, 3]) -> R(N, G, C/G, H, W) -> T(perm=[0, 1, 3, 2, 4]) -> R(N, C, H, W) -> T(perm=[0, 2, 1, 3])
72+
/// @param graph_viewer QNN graph viewer.
73+
/// @param transpose_head The first transpose node starting the pattern.
74+
/// @param node_to_node_unit Maps a Node to a NodeUnit.
75+
/// @param node_unit_to_qnn_node_group Maps a NodeUnit to a IQnnNodeGroup.
76+
/// @return The matched pattern as an array of NodeUnits if found, otherwise std::nullopt.
77+
/// @note This is ChannelShuffle with transpose wraps commonly seen ORT post partitioning.
78+
std::optional<std::array<const NodeUnit*, 5>> MatchChannelShufflePattern(
79+
const GraphViewer& graph_viewer,
80+
const NodeUnit* transpose_head,
81+
const MapNodeToNodeUnit& node_to_node_unit,
82+
const MapNodeUnitToGroup& node_unit_to_qnn_node_group) {
83+
// Helper function to get a single child of a specific type
84+
auto GetChildOfType = [&](const NodeUnit& node, std::string_view expect_type) -> const NodeUnit* {
85+
const std::array<std::string_view, 1> child_op_types{expect_type};
86+
const NodeUnit* child = GetOnlyChildOfType(
87+
graph_viewer, node, child_op_types, node_to_node_unit, node_unit_to_qnn_node_group);
88+
if (child == nullptr) {
89+
return nullptr;
90+
}
91+
if (child->OpType() != expect_type) {
92+
return nullptr;
93+
}
94+
if (child->UnitType() != NodeUnit::Type::SingleNode) {
95+
return nullptr;
96+
}
97+
return child;
98+
};
99+
100+
if (transpose_head->OpType() != kOpTranspose) {
101+
return std::nullopt;
102+
}
103+
if (transpose_head->UnitType() != NodeUnit::Type::SingleNode) {
104+
return std::nullopt;
105+
}
106+
const NodeUnit* reshape1 = GetChildOfType(*transpose_head, kOpReshape);
107+
if (reshape1 == nullptr) {
108+
return std::nullopt;
109+
}
110+
const NodeUnit* transpose = GetChildOfType(*reshape1, kOpTranspose);
111+
if (transpose == nullptr) {
112+
return std::nullopt;
113+
}
114+
const NodeUnit* reshape2 = GetChildOfType(*transpose, kOpReshape);
115+
if (reshape2 == nullptr) {
116+
return std::nullopt;
117+
}
118+
const NodeUnit* transpose_tail = GetChildOfType(*reshape2, kOpTranspose);
119+
if (transpose_tail == nullptr) {
120+
return std::nullopt;
121+
}
122+
return std::array<const NodeUnit*, 5>{transpose_head, reshape1, transpose, reshape2, transpose_tail};
123+
}
124+
125+
/// @brief Create or validate the QNN node of type ChannelShuffle.
126+
/// @param qnn_model_wrapper QNN model wrapper
127+
/// @param node_units The node units containing the nodes in pattern
128+
/// @param validate Whether to validate the QNN node
129+
/// @return Status
130+
Status CreateOrValidateOnQnn(
131+
QnnModelWrapper* qnn_model_wrapper,
132+
gsl::span<const NodeUnit* const> node_units,
133+
bool validate) {
134+
const NodeUnit* transpose_head = node_units[0];
135+
const NodeUnit* transpose_tail = node_units[4];
136+
const NodeUnitIODef& cs_input_def = transpose_head->Inputs()[0];
137+
const NodeUnitIODef& cs_output_def = transpose_tail->Outputs()[0];
138+
139+
std::vector<std::string> param_tensor_names;
140+
std::vector<Qnn_Param_t> param_tensors;
141+
{
142+
auto transpose_head_proto = transpose_head->GetNode().InputDefs()[0]->Shape();
143+
ORT_RETURN_IF_NOT(transpose_head_proto != nullptr, "Failed to get input shape proto.");
144+
TensorShape transpose_head_input_shape = utils::GetTensorProtoShape(*transpose_head_proto);
145+
const uint32_t channel_axis = static_cast<uint32_t>(transpose_head_input_shape.NumDimensions() - 1);
146+
Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
147+
axis_scalar.dataType = QNN_DATATYPE_UINT_32;
148+
axis_scalar.uint32Value = channel_axis;
149+
QnnParamWrapper param_wrapper(transpose_tail->Index(),
150+
transpose_tail->Name(),
151+
QNN_OP_CHANNEL_SHUFFLE_PARAM_AXIS,
152+
axis_scalar);
153+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
154+
param_tensor_names.push_back(param_wrapper.GetParamTensorName());
155+
param_tensors.push_back(param_wrapper.GetQnnParam());
156+
}
157+
{
158+
// Extract channel dimension from transpose (from channel last -> first)
159+
const NodeUnit* reshape1 = node_units[1];
160+
auto reshape1_proto = reshape1->GetNode().OutputDefs()[0]->Shape();
161+
ORT_RETURN_IF_NOT(reshape1_proto != nullptr, "Failed to get input shape proto.");
162+
TensorShape reshape1_output_shape = utils::GetTensorProtoShape(*reshape1_proto);
163+
Qnn_Scalar_t num_groups_scalar = QNN_SCALAR_INIT;
164+
num_groups_scalar.dataType = QNN_DATATYPE_UINT_32;
165+
num_groups_scalar.uint32Value = static_cast<uint32_t>(reshape1_output_shape.GetDims()[1]);
166+
QnnParamWrapper param_wrapper(transpose_tail->Index(),
167+
transpose_tail->Name(),
168+
QNN_OP_CHANNEL_SHUFFLE_PARAM_NUM_GROUPS,
169+
num_groups_scalar);
170+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
171+
param_tensor_names.push_back(param_wrapper.GetParamTensorName());
172+
param_tensors.push_back(param_wrapper.GetQnnParam());
173+
}
174+
175+
QnnTensorWrapper channel_shuffle_input;
176+
QnnTensorWrapper channel_shuffle_output;
177+
ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(cs_input_def, channel_shuffle_input));
178+
ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(cs_output_def, channel_shuffle_output));
179+
180+
// Note: Skipped QNN validation API due to its inconsistent behavior than creation API. Re-enable it when fixed.
181+
if (!validate) {
182+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_input)), "Failed to add input");
183+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_output)), "Failed to add output");
184+
ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(transpose_tail->Name(),
185+
QNN_OP_PACKAGE_NAME_QTI_AISW,
186+
QNN_OP_CHANNEL_SHUFFLE,
187+
{cs_input_def.node_arg.Name()},
188+
{cs_output_def.node_arg.Name()},
189+
std::move(param_tensor_names),
190+
validate),
191+
"Failed to add fused " + std::string(kOpChannelShuffle) + " node.");
192+
}
193+
194+
return Status::OK();
195+
}
196+
197+
} // namespace
198+
199+
// Attempts to fuse the channel-shuffle pattern rooted at `transpose_head`
// (Transpose -> Reshape -> Transpose -> Reshape -> Transpose) into a single
// QNN ChannelShuffle node group. Returns nullptr when the pattern does not
// match, any shape/perm constraint fails, or QNN-side validation rejects it.
std::unique_ptr<IQnnNodeGroup> ChannelShuffleFusion::TryFusion(
    QnnModelWrapper& qnn_model_wrapper,
    const NodeUnit& transpose_head,
    const MapNodeToNodeUnit& node_to_node_unit,
    const MapNodeUnitToGroup& node_unit_to_qnn_node_group,
    [[maybe_unused]] const logging::Logger& logger) {
  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
  std::optional<std::array<const NodeUnit*, 5>> pattern = MatchChannelShufflePattern(
      graph_viewer, &transpose_head, node_to_node_unit, node_unit_to_qnn_node_group);
  if (!pattern.has_value()) {
    return nullptr;
  }
  const NodeUnit* reshape1 = pattern->at(1);
  const NodeUnit* transpose = pattern->at(2);
  const NodeUnit* reshape2 = pattern->at(3);

  // Input shape to reshape1 must equal output shape of reshape2; and has rank > 2
  // (i.e., the reshape pair must restore the original shape exactly).
  auto reshape1_input0_proto = reshape1->GetNode().InputDefs()[0]->Shape();
  auto reshape2_output_proto = reshape2->GetNode().OutputDefs()[0]->Shape();
  if (reshape1_input0_proto == nullptr || reshape2_output_proto == nullptr) {
    return nullptr;
  }
  TensorShape reshape1_input0_shape = utils::GetTensorProtoShape(*reshape1_input0_proto);
  TensorShape reshape2_output_shape = utils::GetTensorProtoShape(*reshape2_output_proto);
  if (reshape1_input0_shape.NumDimensions() != reshape2_output_shape.NumDimensions()) {
    return nullptr;
  }
  gsl::span<const int64_t> reshape1_input0_dims = reshape1_input0_shape.GetDims();
  gsl::span<const int64_t> reshape2_output_dims = reshape2_output_shape.GetDims();
  if (reshape1_input0_dims != reshape2_output_dims) {
    return nullptr;
  }

  // Intermediate shape must be 1 rank higher than input shape
  auto reshape1_output_proto = reshape1->GetNode().OutputDefs()[0]->Shape();
  if (reshape1_output_proto == nullptr) {
    return nullptr;
  }
  TensorShape reshape1_output_shape = utils::GetTensorProtoShape(*reshape1_output_proto);

  // Intermediate shape must split channels in groups only:
  // (N, C, ...) -> (N, G, C/G, ...) with the batch and trailing dims unchanged.
  gsl::span<const int64_t> reshape1_output_dims = reshape1_output_shape.GetDims();
  if (reshape1_input0_dims[0] != reshape1_output_dims[0]) {
    return nullptr;
  }
  if (reshape1_output_dims.size() < 3) {
    return nullptr;
  }
  if (reshape1_input0_dims[1] != (reshape1_output_dims[1] * reshape1_output_dims[2])) {
    return nullptr;
  }
  if (reshape1_output_dims.size() != reshape1_input0_dims.size() + 1) {
    return nullptr;
  }
  // All dims after the (G, C/G) split must match the original trailing dims.
  size_t remaining_dims = reshape1_input0_dims.size() - 2;
  if (reshape1_output_dims.size() < remaining_dims + 3) {
    return nullptr;
  }
  for (size_t i = 0; i < remaining_dims; ++i) {
    if (reshape1_input0_dims[i + 2] != reshape1_output_dims[i + 3]) {
      return nullptr;
    }
  }

  // Intermediate transpose must only permute channels:
  // swapping axes 1 and 2 of its perm must recover the identity permutation.
  std::optional<std::vector<int64_t>> perm = GetTransposePerm(*transpose);
  if (!perm.has_value()) {
    return nullptr;
  }
  std::vector<int64_t> perm_to_check = perm.value();
  std::swap(perm_to_check[1], perm_to_check[2]);
  std::vector<int64_t> perm_expected(perm_to_check.size());
  for (size_t i = 0; i < perm_expected.size(); ++i) {
    perm_expected[i] = static_cast<int64_t>(i);
  }
  if (perm_to_check != perm_expected) {
    return nullptr;
  }

  // Check if the first and last transpose is a canceling transpose pair
  // (so the fused op can operate directly on the head's input layout).
  const NodeUnit* transpose_tail = pattern->at(4);
  std::optional<std::vector<int64_t>> perm_head = GetTransposePerm(transpose_head);
  if (!perm_head.has_value()) {
    return nullptr;
  }
  std::optional<std::vector<int64_t>> perm_tail = GetTransposePerm(*transpose_tail);
  if (!perm_tail.has_value()) {
    return nullptr;
  }
  if (!IsCancelingTransposePermPair(perm_head, perm_tail)) {
    return nullptr;
  }

  // Dry-run creation (validate=true) to confirm QNN accepts the fused op.
  if (CreateOrValidateOnQnn(&qnn_model_wrapper, pattern.value(), /*validate=*/true) != Status::OK()) {
    return nullptr;
  }
  return std::make_unique<ChannelShuffleFusion>(pattern.value());
}
297+
298+
// Exposes the five fused NodeUnits in pattern order
// (transpose_head, reshape1, transpose, reshape2, transpose_tail).
gsl::span<const NodeUnit* const> ChannelShuffleFusion::GetNodeUnits() const {
  return gsl::make_span<const NodeUnit* const>(node_units_.data(), node_units_.size());
}
301+
302+
// Validates the fused ChannelShuffle op against QNN without mutating the
// model (delegates to CreateOrValidateOnQnn with validate=true).
Status ChannelShuffleFusion::IsSupported(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true);
}
306+
307+
// Adds the fused ChannelShuffle node, its input/output tensors, and its
// params to the QNN model (delegates to CreateOrValidateOnQnn with
// validate=false).
Status ChannelShuffleFusion::AddToModelBuilder(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false);
}
311+
312+
} // namespace qnn
313+
} // namespace onnxruntime
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <gsl/gsl>
7+
#include <array>
8+
#include <memory>
9+
#include <unordered_map>
10+
#include <vector>
11+
12+
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
13+
#include "core/providers/qnn/ort_api.h"
14+
15+
namespace onnxruntime {
16+
namespace qnn {
17+
18+
class QnnModelWrapper;
19+
20+
/// <summary>
21+
/// Represents a fusion of pattern: Transpose -> ChannelShuffle (Reshape -> Transpose -> Reshape) -> Transpose
22+
/// </summary>
23+
class ChannelShuffleFusion : public IQnnNodeGroup {
24+
public:
25+
explicit ChannelShuffleFusion(gsl::span<const NodeUnit* const> node_units) {
26+
ORT_ENFORCE(node_units.size() == 5, "Pattern expect exactly 5 NodeUnits.");
27+
node_units_[0] = node_units[0];
28+
node_units_[1] = node_units[1];
29+
node_units_[2] = node_units[2];
30+
node_units_[3] = node_units[3];
31+
node_units_[4] = node_units[4];
32+
}
33+
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ChannelShuffleFusion);
34+
35+
Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
36+
Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
37+
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
38+
const NodeUnit* GetTargetNodeUnit() const override { return node_units_[0]; }
39+
std::string_view Type() const override { return "ChannelShuffleFusion"; }
40+
41+
/// <summary>
42+
/// Traverses graph to check if the given starting NodeUnit is part of a channel shuffle pattern.
43+
/// If so, returns a IQnnNodeGroup that contains the ChannelShuffle NodeUnits.
44+
/// </summary>
45+
static std::unique_ptr<IQnnNodeGroup> TryFusion(
46+
QnnModelWrapper& qnn_model_wrapper,
47+
const NodeUnit& transpose_node_unit,
48+
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
49+
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
50+
const logging::Logger& logger);
51+
52+
private:
53+
std::array<const NodeUnit*, 5> node_units_;
54+
};
55+
56+
} // namespace qnn
57+
} // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
1717
#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"
1818
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"
19+
#include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h"
20+
1921
#include "core/providers/qnn/builder/qnn_utils.h"
2022
#include "core/providers/qnn/ort_api.h"
2123

@@ -92,7 +94,7 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
9294
{"HardSigmoid", HardSigmoidMulFusion::TryFusion},
9395
{"Gemm", ReshapeGemmFusion::TryFusion},
9496
{"Mul", ScaleSoftmaxFusion::TryFusion},
95-
};
97+
{"Transpose", ChannelShuffleFusion::TryFusion}};
9698

9799
// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
98100
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) {

0 commit comments

Comments
 (0)