@@ -48,7 +48,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
// These produce int64 indices output, which can't be quantized, so there's no downstream Q node.
static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
return {{"ArgMax", {}},
{"ArgMin", {}}};
{"ArgMin", {}},
{"NonZero", {}}};
}

static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
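For context on the drop-DQ list above, a small numpy sketch (illustrative only, not part of the change) of why NonZero qualifies: like ArgMax and ArgMin, it produces int64 indices, which cannot be quantized, so no Q node can follow it.

import numpy as np

x = np.array([[3.0, 0.0], [0.0, 5.0]], dtype=np.float32)
indices = np.stack(np.nonzero(x))  # ONNX NonZero layout: [input_rank, num_nonzero]
print(indices.shape)  # (2, 2)
print(indices.dtype)  # int64 -- index output, not quantizable, so only the DQ is dropped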
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -223,6 +223,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreateInverseOpBuilder("Inverse", *this);
}

{
CreateNonZeroOpBuilder("NonZero", *this);
}
}

const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -125,5 +125,7 @@ void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_

void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateNonZeroOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace qnn
} // namespace onnxruntime
@@ -156,6 +156,7 @@ class BaseOpBuilder : public IOpBuilder {
{"Max", QNN_OP_ELEMENT_WISE_MAXIMUM},
{"Min", QNN_OP_ELEMENT_WISE_MINIMUM},
{"Neg", QNN_OP_ELEMENT_WISE_NEG},
{"NonZero", QNN_OP_NON_ZERO},
{"Not", QNN_OP_ELEMENT_WISE_NOT},
{"Or", QNN_OP_ELEMENT_WISE_OR},
{"Pow", QNN_OP_ELEMENT_WISE_POWER},
117 changes: 117 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/nonzero_op_builder.cc
@@ -0,0 +1,117 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

class NonZeroOpBuilder : public BaseOpBuilder {
public:
NonZeroOpBuilder() : BaseOpBuilder("NonZeroOpBuilder") {}

protected:
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status NonZeroOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
// Explicitly handle a corner case that can pass backend validation but is in fact not executable.
const std::vector<uint32_t>& input_shape = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]).GetTensorDims();
for (const uint32_t& dim : input_shape) {
ORT_RETURN_IF(dim == 0, "QNN does not support NonZero with empty input.");
}

const auto& output = node_unit.Outputs()[0];
const std::string& output_name = output.node_arg.Name();

TensorInfo output_info = {};
Status status = qnn_model_wrapper.GetTensorInfo(output, output_info);
if (!status.IsOK()) {
LOGS(logger, ERROR) << "Encountered NonZero node " << node_unit.Name() << ", which has a dynamically shaped output tensor. "
                    << "QNN supports NonZero by allocating the maximum possible size (i.e., all elements != 0) "
                    << "and filling only the detected nonzero elements in the output tensor. "
                    << "The model must be preprocessed to eliminate dynamic shapes first for QNN to support it.";
return status;
}

// ONNX NonZero output has shape [input_rank, #input_elements].
uint32_t rank = output_info.shape[0];
uint32_t num_elements = output_info.shape[1];

// QNN NonZero output has shape [#input_elements, input_rank], so an extra Transpose must be inserted afterwards.
const std::string transpose_input_name = utils::GetUniqueName(output_name, "_transpose");
const std::vector<uint32_t> transpose_input_shape{num_elements, rank};
QnnTensorWrapper output_tensorwrapper(transpose_input_name,
QNN_TENSOR_TYPE_NATIVE,
output_info.qnn_data_type,
output_info.quant_param.Copy(),
std::vector<uint32_t>(transpose_input_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
std::move(input_names),
{transpose_input_name},
{},
do_op_validation),
"Failed to add NonZero node.");

// NonZero's output holds INT64 indices. If the output is also a graph output, add a Cast node to convert the
// dtype back to INT64, since wrapper construction implicitly changes the dtype to INT32.
const bool is_cast_required = output_info.qnn_data_type == QNN_DATATYPE_INT_64 &&
qnn_model_wrapper.IsGraphOutput(output_name);
const std::string transpose_output_name = is_cast_required ? utils::GetUniqueName(output_name, "_cast") : output_name;

std::vector<uint32_t> transpose_perm{1, 0};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
transpose_input_name,
transpose_output_name,
transpose_input_shape,
transpose_perm,
output_info.shape,
output_info.qnn_data_type,
output_info.quant_param,
do_op_validation,
false,
false));

if (is_cast_required) {
QnnTensorWrapper cast_output_tensorwrapper(output_name,
QNN_TENSOR_TYPE_APP_READ,
output_info.qnn_data_type,
output_info.quant_param.Copy(),
std::vector<uint32_t>(output_info.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output_tensorwrapper)),
"Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_CAST),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_CAST,
{transpose_output_name},
{output_name},
{},
do_op_validation),
"Failed to add node");
}

return Status::OK();
}

void CreateNonZeroOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<NonZeroOpBuilder>());
}

} // namespace qnn
} // namespace onnxruntime
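To make the layout handling above concrete, a minimal numpy sketch (illustrative only) of the two index layouts the builder reconciles with the inserted Transpose:

import numpy as np

x = np.array([[1, 0], [0, 2]])

onnx_layout = np.stack(np.nonzero(x))  # [input_rank, num_nonzero], as ONNX NonZero emits
qnn_layout = np.argwhere(x)            # [num_nonzero, input_rank], as QNN NonZero emits

# The Transpose node with perm {1, 0} maps the QNN layout back to the ONNX layout.
assert np.array_equal(onnx_layout, qnn_layout.T)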
@@ -17,6 +17,7 @@
from ...onnx_model import ONNXModel
from .fusion_lpnorm import FusionLpNormalization
from .fusion_spacetodepth import FusionSpaceToDepth
from .shape_nonzero import ShapeNonZero


def qnn_preprocess_model(
@@ -108,6 +109,9 @@ def qnn_preprocess_model(
if exclude_initializer_from_input:
modified |= remove_initializer_from_input(onnx_model.model)

# Shape dynamic-shaped NonZero.
modified |= ShapeNonZero(onnx_model).apply()

# Fuse Erf sequence into a single Gelu
fusion_gelu = FusionGelu(onnx_model)
if fusion_gelu.apply():
@@ -166,7 +170,7 @@ def qnn_preprocess_model(
if modified:
onnx_model.topological_sort()
onnx.save_model(
model,
onnx_model.model,
model_output,
save_as_external_data=save_as_external_data,
all_tensors_to_one_file=all_tensors_to_one_file,
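For reference, a hedged usage sketch of the preprocessing entry point changed above, which now runs ShapeNonZero before the fusions. It assumes the standard onnxruntime quantization package layout; the file names are hypothetical.

from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# Returns True if the model was modified, e.g., when a NonZero output shape was pinned.
modified = qnn_preprocess_model("model.onnx", "model.qnn_preprocessed.onnx")
print(modified)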
@@ -0,0 +1,85 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""Define NonZero shape inference."""

import logging

import numpy as np
import onnx

from ... import fusions, onnx_model


class ShapeNonZero(fusions.Fusion):
"""Shape inference for NonZero.

A NonZero node produces a dynamically shaped output tensor, leaving the tensor shapes of downstream nodes
undetermined as well. QNN expects NonZero to have its output shape set to the maximum possible size (i.e., the
total number of input elements) and lets the runtime handle the dynamic shape at execution time.
"""

def __init__(self, model: onnx_model.ONNXModel):
"""Initialize.
Args:
model: An onnx_model.ONNXModel instance.
"""
super().__init__(model, "", "NonZero")

def fuse(
self,
node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""Infer shape for NonZero.

Args:
node: An onnx.NodeProto matching the specified search type (i.e., NonZero).
input_name_to_nodes: A dict mapping tensor name to consumed nodes.
output_name_to_node: A dict mapping tensor name to produced node.

Returns:
A bool indicating whether the node is updated.
"""
logging.warning(
    "The model contains a NonZero node, which produces a dynamically shaped output tensor. "
    "Following QNN requirements, its output shape will be deliberately set to the maximum possible size."
)

if (input_tensor_type := self.model.get_tensor_type(node.input[0])) is None or (
output_tensor_type := self.model.get_tensor_type(node.output[0])
) is None:
return False

if not (input_tensor_shape := self.tensor_shape_to_list(input_tensor_type)):
return False

if not all(isinstance(dim, int) for dim in input_tensor_shape):
return False

# Pin the dynamic dimension to the maximum possible size, i.e., the total number of input elements.
output_tensor_type.shape.dim[1].dim_value = int(np.prod(input_tensor_shape))
return True

def apply(self) -> bool:
"""Apply fusion.

This method is overridden to rerun shape inference after NonZero nodes have been given fixed output shapes.

Returns:
A bool indicating whether the model is updated.
"""
input_name_to_nodes = self.model.input_name_to_nodes()
output_name_to_node = self.model.output_name_to_node()

updated = False
for node in self.model.nodes():
if node.op_type == self.search_op_type:
updated |= self.fuse(node, input_name_to_nodes, output_name_to_node)

if updated:
self.model.model = onnx.shape_inference.infer_shapes(self.model.model)

return updated
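A standalone sketch (using only onnx and numpy, independent of the fusions/onnx_model helpers) of the shape pinning that fuse() performs, followed by the shape inference rerun from apply():

import numpy as np
import onnx
from onnx import TensorProto, helper

# Illustrative model: NonZero over a fixed-shape [2, 3] input; its output is
# [input_rank, N] where N is dynamic (unset second dimension).
graph = helper.make_graph(
    [helper.make_node("NonZero", ["x"], ["y"])],
    "g",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info("y", TensorProto.INT64, [2, None])],
)
model = helper.make_model(graph)

# Pin the dynamic dimension to the maximum possible size: prod([2, 3]) = 6.
y_type = model.graph.output[0].type.tensor_type
y_type.shape.dim[1].dim_value = int(np.prod([2, 3]))
model = onnx.shape_inference.infer_shapes(model)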