|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h" |
| 5 | + |
| 6 | +#include <gsl/gsl> |
| 7 | +#include <optional> |
| 8 | +#include <utility> |
| 9 | +#include <string> |
| 10 | +#include <array> |
| 11 | +#include <memory> |
| 12 | +#include <unordered_map> |
| 13 | +#include <vector> |
| 14 | + |
| 15 | +#include "core/common/inlined_containers.h" |
| 16 | +#include "core/providers/qnn/builder/qnn_utils.h" |
| 17 | +#include "core/providers/qnn/builder/op_builder_factory.h" |
| 18 | +#include "core/providers/qnn/builder/qnn_node_group/utils.h" |
| 19 | +#include "core/providers/qnn/builder/qnn_model_wrapper.h" |
| 20 | +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" |
| 21 | + |
| 22 | +namespace onnxruntime { |
| 23 | +namespace qnn { |
| 24 | +namespace { |
| 25 | + |
// ONNX attribute / op-type names referenced by the pattern matcher below.
constexpr char kAttrTransposePerm[] = "perm";
constexpr char kOpChannelShuffle[] = "ChannelShuffle";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpReshape[] = "Reshape";

// Lookup tables supplied by the partitioner: Node -> owning NodeUnit, and
// NodeUnit -> the IQnnNodeGroup that has already claimed it (if any).
using MapNodeToNodeUnit = std::unordered_map<const Node*, const NodeUnit*>;
using MapNodeUnitToGroup = std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>;
| 33 | + |
| 34 | +std::optional<std::vector<int64_t>> GetTransposePerm(const NodeUnit& transpose) { |
| 35 | + if (transpose.OpType() != kOpTranspose) { |
| 36 | + return std::nullopt; |
| 37 | + } |
| 38 | + NodeAttrHelper helper(transpose.GetNode()); |
| 39 | + return helper.Get(kAttrTransposePerm, std::vector<int64_t>()); |
| 40 | +} |
| 41 | + |
| 42 | +std::vector<int64_t> InvertTransposePerm(gsl::span<const int64_t> perm) { |
| 43 | + const size_t perm_size = perm.size(); |
| 44 | + std::vector<int64_t> perm_inverse(perm_size); |
| 45 | + for (size_t i = 0; i < perm_size; ++i) { |
| 46 | + size_t j = gsl::narrow_cast<size_t>(perm[i]); |
| 47 | + perm_inverse[j] = gsl::narrow_cast<int64_t>(i); |
| 48 | + } |
| 49 | + return perm_inverse; |
| 50 | +} |
| 51 | + |
| 52 | +bool IsCancelingTransposePermPair( |
| 53 | + std::optional<gsl::span<const int64_t>> perm1, |
| 54 | + std::optional<gsl::span<const int64_t>> perm2) { |
| 55 | + if (!perm1.has_value() || !perm2.has_value()) { |
| 56 | + return false; |
| 57 | + } |
| 58 | + if (perm1->size() != perm2->size()) { |
| 59 | + return false; |
| 60 | + } |
| 61 | + std::vector<int64_t> perm1_inverted_vector = InvertTransposePerm(*perm1); |
| 62 | + auto perm1_inverted = gsl::make_span<const int64_t>( |
| 63 | + perm1_inverted_vector.data(), perm1_inverted_vector.size()); |
| 64 | + if (perm1_inverted != perm2.value()) { |
| 65 | + return false; |
| 66 | + } |
| 67 | + return true; |
| 68 | +} |
| 69 | + |
| 70 | +/// @brief Match pattern: Transpose -> ChannelShuffle (Reshape -> Transpose -> Reshape) -> Transpose |
| 71 | +/// E.g.,: T(perm=[0, 2, 1, 3]) -> R(N, G, C/G, H, W) -> T(perm=[0, 1, 3, 2, 4]) -> R(N, C, H, W) -> T(perm=[0, 2, 1, 3]) |
| 72 | +/// @param graph_viewer QNN graph viewer. |
| 73 | +/// @param transpose_head The first transpose node starting the pattern. |
| 74 | +/// @param node_to_node_unit Maps a Node to a NodeUnit. |
| 75 | +/// @param node_unit_to_qnn_node_group Maps a NodeUnit to a IQnnNodeGroup. |
| 76 | +/// @return The matched pattern as an array of NodeUnits if found, otherwise std::nullopt. |
| 77 | +/// @note This is ChannelShuffle with transpose wraps commonly seen ORT post partitioning. |
| 78 | +std::optional<std::array<const NodeUnit*, 5>> MatchChannelShufflePattern( |
| 79 | + const GraphViewer& graph_viewer, |
| 80 | + const NodeUnit* transpose_head, |
| 81 | + const MapNodeToNodeUnit& node_to_node_unit, |
| 82 | + const MapNodeUnitToGroup& node_unit_to_qnn_node_group) { |
| 83 | + // Helper function to get a single child of a specific type |
| 84 | + auto GetChildOfType = [&](const NodeUnit& node, std::string_view expect_type) -> const NodeUnit* { |
| 85 | + const std::array<std::string_view, 1> child_op_types{expect_type}; |
| 86 | + const NodeUnit* child = GetOnlyChildOfType( |
| 87 | + graph_viewer, node, child_op_types, node_to_node_unit, node_unit_to_qnn_node_group); |
| 88 | + if (child == nullptr) { |
| 89 | + return nullptr; |
| 90 | + } |
| 91 | + if (child->OpType() != expect_type) { |
| 92 | + return nullptr; |
| 93 | + } |
| 94 | + if (child->UnitType() != NodeUnit::Type::SingleNode) { |
| 95 | + return nullptr; |
| 96 | + } |
| 97 | + return child; |
| 98 | + }; |
| 99 | + |
| 100 | + if (transpose_head->OpType() != kOpTranspose) { |
| 101 | + return std::nullopt; |
| 102 | + } |
| 103 | + if (transpose_head->UnitType() != NodeUnit::Type::SingleNode) { |
| 104 | + return std::nullopt; |
| 105 | + } |
| 106 | + const NodeUnit* reshape1 = GetChildOfType(*transpose_head, kOpReshape); |
| 107 | + if (reshape1 == nullptr) { |
| 108 | + return std::nullopt; |
| 109 | + } |
| 110 | + const NodeUnit* transpose = GetChildOfType(*reshape1, kOpTranspose); |
| 111 | + if (transpose == nullptr) { |
| 112 | + return std::nullopt; |
| 113 | + } |
| 114 | + const NodeUnit* reshape2 = GetChildOfType(*transpose, kOpReshape); |
| 115 | + if (reshape2 == nullptr) { |
| 116 | + return std::nullopt; |
| 117 | + } |
| 118 | + const NodeUnit* transpose_tail = GetChildOfType(*reshape2, kOpTranspose); |
| 119 | + if (transpose_tail == nullptr) { |
| 120 | + return std::nullopt; |
| 121 | + } |
| 122 | + return std::array<const NodeUnit*, 5>{transpose_head, reshape1, transpose, reshape2, transpose_tail}; |
| 123 | +} |
| 124 | + |
| 125 | +/// @brief Create or validate the QNN node of type ChannelShuffle. |
| 126 | +/// @param qnn_model_wrapper QNN model wrapper |
| 127 | +/// @param node_units The node units containing the nodes in pattern |
| 128 | +/// @param validate Whether to validate the QNN node |
| 129 | +/// @return Status |
| 130 | +Status CreateOrValidateOnQnn( |
| 131 | + QnnModelWrapper* qnn_model_wrapper, |
| 132 | + gsl::span<const NodeUnit* const> node_units, |
| 133 | + bool validate) { |
| 134 | + const NodeUnit* transpose_head = node_units[0]; |
| 135 | + const NodeUnit* transpose_tail = node_units[4]; |
| 136 | + const NodeUnitIODef& cs_input_def = transpose_head->Inputs()[0]; |
| 137 | + const NodeUnitIODef& cs_output_def = transpose_tail->Outputs()[0]; |
| 138 | + |
| 139 | + std::vector<std::string> param_tensor_names; |
| 140 | + std::vector<Qnn_Param_t> param_tensors; |
| 141 | + { |
| 142 | + auto transpose_head_proto = transpose_head->GetNode().InputDefs()[0]->Shape(); |
| 143 | + ORT_RETURN_IF_NOT(transpose_head_proto != nullptr, "Failed to get input shape proto."); |
| 144 | + TensorShape transpose_head_input_shape = utils::GetTensorProtoShape(*transpose_head_proto); |
| 145 | + const uint32_t channel_axis = static_cast<uint32_t>(transpose_head_input_shape.NumDimensions() - 1); |
| 146 | + Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; |
| 147 | + axis_scalar.dataType = QNN_DATATYPE_UINT_32; |
| 148 | + axis_scalar.uint32Value = channel_axis; |
| 149 | + QnnParamWrapper param_wrapper(transpose_tail->Index(), |
| 150 | + transpose_tail->Name(), |
| 151 | + QNN_OP_CHANNEL_SHUFFLE_PARAM_AXIS, |
| 152 | + axis_scalar); |
| 153 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 154 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 155 | + param_tensors.push_back(param_wrapper.GetQnnParam()); |
| 156 | + } |
| 157 | + { |
| 158 | + // Extract channel dimension from transpose (from channel last -> first) |
| 159 | + const NodeUnit* reshape1 = node_units[1]; |
| 160 | + auto reshape1_proto = reshape1->GetNode().OutputDefs()[0]->Shape(); |
| 161 | + ORT_RETURN_IF_NOT(reshape1_proto != nullptr, "Failed to get input shape proto."); |
| 162 | + TensorShape reshape1_output_shape = utils::GetTensorProtoShape(*reshape1_proto); |
| 163 | + Qnn_Scalar_t num_groups_scalar = QNN_SCALAR_INIT; |
| 164 | + num_groups_scalar.dataType = QNN_DATATYPE_UINT_32; |
| 165 | + num_groups_scalar.uint32Value = static_cast<uint32_t>(reshape1_output_shape.GetDims()[1]); |
| 166 | + QnnParamWrapper param_wrapper(transpose_tail->Index(), |
| 167 | + transpose_tail->Name(), |
| 168 | + QNN_OP_CHANNEL_SHUFFLE_PARAM_NUM_GROUPS, |
| 169 | + num_groups_scalar); |
| 170 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 171 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 172 | + param_tensors.push_back(param_wrapper.GetQnnParam()); |
| 173 | + } |
| 174 | + |
| 175 | + QnnTensorWrapper channel_shuffle_input; |
| 176 | + QnnTensorWrapper channel_shuffle_output; |
| 177 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(cs_input_def, channel_shuffle_input)); |
| 178 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(cs_output_def, channel_shuffle_output)); |
| 179 | + |
| 180 | + // Note: Skipped QNN validation API due to its inconsistent behavior than creation API. Re-enable it when fixed. |
| 181 | + if (!validate) { |
| 182 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_input)), "Failed to add input"); |
| 183 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_output)), "Failed to add output"); |
| 184 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(transpose_tail->Name(), |
| 185 | + QNN_OP_PACKAGE_NAME_QTI_AISW, |
| 186 | + QNN_OP_CHANNEL_SHUFFLE, |
| 187 | + {cs_input_def.node_arg.Name()}, |
| 188 | + {cs_output_def.node_arg.Name()}, |
| 189 | + std::move(param_tensor_names), |
| 190 | + validate), |
| 191 | + "Failed to add fused " + std::string(kOpChannelShuffle) + " node."); |
| 192 | + } |
| 193 | + |
| 194 | + return Status::OK(); |
| 195 | +} |
| 196 | + |
| 197 | +} // namespace |
| 198 | + |
| 199 | +std::unique_ptr<IQnnNodeGroup> ChannelShuffleFusion::TryFusion( |
| 200 | + QnnModelWrapper& qnn_model_wrapper, |
| 201 | + const NodeUnit& transpose_head, |
| 202 | + const MapNodeToNodeUnit& node_to_node_unit, |
| 203 | + const MapNodeUnitToGroup& node_unit_to_qnn_node_group, |
| 204 | + [[maybe_unused]] const logging::Logger& logger) { |
| 205 | + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); |
| 206 | + std::optional<std::array<const NodeUnit*, 5>> pattern = MatchChannelShufflePattern( |
| 207 | + graph_viewer, &transpose_head, node_to_node_unit, node_unit_to_qnn_node_group); |
| 208 | + if (!pattern.has_value()) { |
| 209 | + return nullptr; |
| 210 | + } |
| 211 | + const NodeUnit* reshape1 = pattern->at(1); |
| 212 | + const NodeUnit* transpose = pattern->at(2); |
| 213 | + const NodeUnit* reshape2 = pattern->at(3); |
| 214 | + |
| 215 | + // Input shape to reshape1 must equal output shape of reshape2; and has rank > 2 |
| 216 | + auto reshape1_input0_proto = reshape1->GetNode().InputDefs()[0]->Shape(); |
| 217 | + auto reshape2_output_proto = reshape2->GetNode().OutputDefs()[0]->Shape(); |
| 218 | + if (reshape1_input0_proto == nullptr || reshape2_output_proto == nullptr) { |
| 219 | + return nullptr; |
| 220 | + } |
| 221 | + TensorShape reshape1_input0_shape = utils::GetTensorProtoShape(*reshape1_input0_proto); |
| 222 | + TensorShape reshape2_output_shape = utils::GetTensorProtoShape(*reshape2_output_proto); |
| 223 | + if (reshape1_input0_shape.NumDimensions() != reshape2_output_shape.NumDimensions()) { |
| 224 | + return nullptr; |
| 225 | + } |
| 226 | + gsl::span<const int64_t> reshape1_input0_dims = reshape1_input0_shape.GetDims(); |
| 227 | + gsl::span<const int64_t> reshape2_output_dims = reshape2_output_shape.GetDims(); |
| 228 | + if (reshape1_input0_dims != reshape2_output_dims) { |
| 229 | + return nullptr; |
| 230 | + } |
| 231 | + |
| 232 | + // Intermediate shape must be 1 rank higher than input shape |
| 233 | + auto reshape1_output_proto = reshape1->GetNode().OutputDefs()[0]->Shape(); |
| 234 | + if (reshape1_output_proto == nullptr) { |
| 235 | + return nullptr; |
| 236 | + } |
| 237 | + TensorShape reshape1_output_shape = utils::GetTensorProtoShape(*reshape1_output_proto); |
| 238 | + |
| 239 | + // Intermediate shape must split channels in groups only |
| 240 | + gsl::span<const int64_t> reshape1_output_dims = reshape1_output_shape.GetDims(); |
| 241 | + if (reshape1_input0_dims[0] != reshape1_output_dims[0]) { |
| 242 | + return nullptr; |
| 243 | + } |
| 244 | + if (reshape1_output_dims.size() < 3) { |
| 245 | + return nullptr; |
| 246 | + } |
| 247 | + if (reshape1_input0_dims[1] != (reshape1_output_dims[1] * reshape1_output_dims[2])) { |
| 248 | + return nullptr; |
| 249 | + } |
| 250 | + if (reshape1_output_dims.size() != reshape1_input0_dims.size() + 1) { |
| 251 | + return nullptr; |
| 252 | + } |
| 253 | + size_t remaining_dims = reshape1_input0_dims.size() - 2; |
| 254 | + if (reshape1_output_dims.size() < remaining_dims + 3) { |
| 255 | + return nullptr; |
| 256 | + } |
| 257 | + for (size_t i = 0; i < remaining_dims; ++i) { |
| 258 | + if (reshape1_input0_dims[i + 2] != reshape1_output_dims[i + 3]) { |
| 259 | + return nullptr; |
| 260 | + } |
| 261 | + } |
| 262 | + |
| 263 | + // Intermediate transpose must only permute channels |
| 264 | + std::optional<std::vector<int64_t>> perm = GetTransposePerm(*transpose); |
| 265 | + if (!perm.has_value()) { |
| 266 | + return nullptr; |
| 267 | + } |
| 268 | + std::vector<int64_t> perm_to_check = perm.value(); |
| 269 | + std::swap(perm_to_check[1], perm_to_check[2]); |
| 270 | + std::vector<int64_t> perm_expected(perm_to_check.size()); |
| 271 | + for (size_t i = 0; i < perm_expected.size(); ++i) { |
| 272 | + perm_expected[i] = static_cast<int64_t>(i); |
| 273 | + } |
| 274 | + if (perm_to_check != perm_expected) { |
| 275 | + return nullptr; |
| 276 | + } |
| 277 | + |
| 278 | + // Check if the first and last transpose is a canceling transpose pair |
| 279 | + const NodeUnit* transpose_tail = pattern->at(4); |
| 280 | + std::optional<std::vector<int64_t>> perm_head = GetTransposePerm(transpose_head); |
| 281 | + if (!perm_head.has_value()) { |
| 282 | + return nullptr; |
| 283 | + } |
| 284 | + std::optional<std::vector<int64_t>> perm_tail = GetTransposePerm(*transpose_tail); |
| 285 | + if (!perm_tail.has_value()) { |
| 286 | + return nullptr; |
| 287 | + } |
| 288 | + if (!IsCancelingTransposePermPair(perm_head, perm_tail)) { |
| 289 | + return nullptr; |
| 290 | + } |
| 291 | + |
| 292 | + if (CreateOrValidateOnQnn(&qnn_model_wrapper, pattern.value(), /*validate=*/true) != Status::OK()) { |
| 293 | + return nullptr; |
| 294 | + } |
| 295 | + return std::make_unique<ChannelShuffleFusion>(pattern.value()); |
| 296 | +} |
| 297 | + |
| 298 | +gsl::span<const NodeUnit* const> ChannelShuffleFusion::GetNodeUnits() const { |
| 299 | + return gsl::span<const NodeUnit* const>{node_units_.data(), node_units_.size()}; |
| 300 | +} |
| 301 | + |
| 302 | +Status ChannelShuffleFusion::IsSupported( |
| 303 | + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { |
| 304 | + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true); |
| 305 | +} |
| 306 | + |
| 307 | +Status ChannelShuffleFusion::AddToModelBuilder( |
| 308 | + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { |
| 309 | + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false); |
| 310 | +} |
| 311 | + |
| 312 | +} // namespace qnn |
| 313 | +} // namespace onnxruntime |
0 commit comments