@@ -245,6 +245,41 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node&
return true;
}

bool ClipNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
// Clip can have 1, 2, or 3 DQ inputs:
// - 1 DQ: only data input is quantized
// - 2 DQ: data and min or max are quantized
// - 3 DQ: data, min, and max are all quantized
const size_t num_dq_nodes = dq_nodes.size();
if (num_dq_nodes < 1 || num_dq_nodes > 3) {
return false;
}

if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, static_cast<int>(num_dq_nodes))) {
return false;
}

int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

if (dt_input != dt_output) {
return false;
}

// 16-bit int types must be explicitly allowed.
if (!allow_16bit_ && Is16BitIntType(dt_input)) {
return false;
}

if (!allow_4bit_ && Is4BitIntType(dt_input)) {
return false;
}

return true;
}

bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
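For readers skimming the diff: the new `ClipNodeGroupSelector::Check` above boils down to a handful of predicates over the DQ/Q nodes surrounding Clip. Here is a minimal standalone sketch of that acceptance rule, with illustrative names and placeholder type codes rather than ORT's actual API:

```cpp
// Sketch only: plain-data stand-ins for the ORT node info used in Check().
#include <cstddef>
#include <cstdint>
#include <iostream>

struct QdqClipGroup {
  std::size_t num_dq_inputs;   // 1 = data; 2 = data + min; 3 = data + min + max
  int32_t dq_input_elem_type;  // element type of the first DQ's input
  int32_t q_output_elem_type;  // element type of the trailing Q's output
};

bool AcceptsClipGroup(const QdqClipGroup& g, bool allow_16bit, bool allow_4bit,
                      bool (*is_16bit)(int32_t), bool (*is_4bit)(int32_t)) {
  if (g.num_dq_inputs < 1 || g.num_dq_inputs > 3) return false;    // 1-3 DQ inputs
  if (g.dq_input_elem_type != g.q_output_elem_type) return false;  // in/out types match
  if (!allow_16bit && is_16bit(g.dq_input_elem_type)) return false;
  if (!allow_4bit && is_4bit(g.dq_input_elem_type)) return false;
  return true;
}

int main() {
  // Hypothetical element-type predicates; the real code checks ONNX elem_type enums.
  auto is_16bit = [](int32_t t) { return t == 16; };
  auto is_4bit = [](int32_t t) { return t == 4; };
  QdqClipGroup g{2, /*dq*/ 8, /*q*/ 8};  // data + min quantized, 8-bit
  std::cout << AcceptsClipGroup(g, true, true, is_16bit, is_4bit) << "\n";  // 1
}
```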
@@ -92,6 +92,20 @@ class UnaryNodeGroupSelector : public NodeGroupSelector {
bool allow_4bit_;
};

class ClipNodeGroupSelector : public NodeGroupSelector {
public:
explicit ClipNodeGroupSelector(bool allow_16bit = true, bool allow_4bit = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}

private:
bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;

bool allow_16bit_;
bool allow_4bit_;
};

// 2 DQ nodes providing input -> node -> Q
class BinaryNodeGroupSelector : public NodeGroupSelector {
public:
@@ -87,9 +87,11 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
{"Neg", {}},
{"DepthToSpace", {}},
{"SpaceToDepth", {}},
{"Clip", {}},
{"LpNormalization", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetClipOpVersionsMap() {
return {{"Clip", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
return {{"Add", {}},
{"Div", {}},
@@ -168,19 +170,26 @@ void RegisterMiscSelectors(Selectors& qdq_selectors) {
}

void RegisterDropDQSelectors(Selectors& qdq_selectors) {
/* register selectors for ops that have a sigle DQ -> node */
/* register selectors for ops that have a single DQ -> node */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<DropDQNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetDropDQOpVersionsMap(),
std::move(selector));
}

void RegisterUnarySelectors(Selectors& qdq_selectors) {
/* regsiter selectors for unary ops */
/* register selectors for unary ops */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<UnaryNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetUnaryOpVersionsMap(),
std::move(selector));
}

void RegisterClipSelector(Selectors& qdq_selectors) {
/* register selector for Clip op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<ClipNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetClipOpVersionsMap(),
std::move(selector));
}

void RegisterBinarySelectors(Selectors& qdq_selectors) {
/* register selectors for binary ops */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<BinaryNodeGroupSelector>();
@@ -305,6 +314,7 @@ void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterDropDQSelectors(qdq_selectors_);
RegisterUnarySelectors(qdq_selectors_);
RegisterClipSelector(qdq_selectors_);
RegisterBinarySelectors(qdq_selectors_);
RegisterVariadicSelectors(qdq_selectors_);
RegisterSplitSelector(qdq_selectors_);
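The builder changes below fold quantized Clip min/max initializers into float values via `utils::Dequantize`. A minimal sketch of the affine mapping this relies on, assuming QNN's scaleOffsetEncoding convention where `offset` is the negated zero point (that convention, and that `utils::Dequantize` implements it, are assumptions here, not confirmed by the diff):

```cpp
// Sketch of affine dequantization: float = (quant + offset) * scale.
#include <cstdint>
#include <iostream>

double DequantizeSketch(int32_t offset, double scale, double quant_value) {
  return (quant_value + static_cast<double>(offset)) * scale;
}

int main() {
  // e.g. an int8 Clip "min" initializer quantized with scale 0.05, zero point -3
  int8_t quantized_min = -120;
  int32_t offset = 3;  // assumed QNN convention: offset = -zero_point
  double scale = 0.05;
  std::cout << DequantizeSketch(offset, scale, quantized_min) << "\n";  // -5.85
}
```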
146 changes: 101 additions & 45 deletions onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
@@ -30,7 +30,7 @@ class ClipOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;

private:
Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
Status ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
};

static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
@@ -41,56 +41,112 @@ static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, input_info));
assert(input_info.is_initializer); // Checked by ExplicitOpCheck().
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, val_bytes));
switch (input_info.qnn_data_type) {
case QNN_DATATYPE_INT_8: {
float_value = static_cast<float>(*reinterpret_cast<int8_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_16: {
float_value = static_cast<float>(*reinterpret_cast<int16_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_32: {
float_value = static_cast<float>(*reinterpret_cast<int32_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_64: {
float_value = static_cast<float>(*reinterpret_cast<int64_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_8: {
float_value = static_cast<float>(*val_bytes.data());
break;
}
case QNN_DATATYPE_UINT_16: {
float_value = static_cast<float>(*reinterpret_cast<uint16_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_32: {
float_value = static_cast<float>(*reinterpret_cast<uint32_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_64: {
float_value = static_cast<float>(*reinterpret_cast<uint64_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_FLOAT_16: {
MLFloat16 fp16_value = *reinterpret_cast<const MLFloat16*>(val_bytes.data());
float_value = fp16_value.ToFloat();
break;

// If the input is quantized, we need to dequantize it
if (input.quant_param.has_value()) {
ORT_RETURN_IF_NOT(input_info.quant_param.IsPerTensor(),
"Clip's min/max must use per-tensor quantization");
const Qnn_QuantizeParams_t& quant_param = input_info.quant_param.Get();

switch (input_info.qnn_data_type) {
case QNN_DATATYPE_SFIXED_POINT_8: {
int8_t quantized_value = *reinterpret_cast<int8_t*>(val_bytes.data());
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
case QNN_DATATYPE_SFIXED_POINT_16: {
int16_t quantized_value = *reinterpret_cast<int16_t*>(val_bytes.data());
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
case QNN_DATATYPE_SFIXED_POINT_32: {
int32_t quantized_value = *reinterpret_cast<int32_t*>(val_bytes.data());
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
case QNN_DATATYPE_UFIXED_POINT_8: {
uint8_t quantized_value = *val_bytes.data();
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
case QNN_DATATYPE_UFIXED_POINT_16: {
uint16_t quantized_value = *reinterpret_cast<uint16_t*>(val_bytes.data());
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
case QNN_DATATYPE_UFIXED_POINT_32: {
uint32_t quantized_value = *reinterpret_cast<uint32_t*>(val_bytes.data());
float_value = static_cast<float>(utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
quant_param.scaleOffsetEncoding.scale,
static_cast<double>(quantized_value)));
break;
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Quantized min/max input data type not supported.");
}
case QNN_DATATYPE_FLOAT_32: {
float_value = *reinterpret_cast<const float*>(val_bytes.data());
break;
} else {
// Non-quantized input, just cast to float
switch (input_info.qnn_data_type) {
case QNN_DATATYPE_INT_8: {
float_value = static_cast<float>(*reinterpret_cast<int8_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_16: {
float_value = static_cast<float>(*reinterpret_cast<int16_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_32: {
float_value = static_cast<float>(*reinterpret_cast<int32_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_INT_64: {
float_value = static_cast<float>(*reinterpret_cast<int64_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_8: {
float_value = static_cast<float>(*val_bytes.data());
break;
}
case QNN_DATATYPE_UINT_16: {
float_value = static_cast<float>(*reinterpret_cast<uint16_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_32: {
float_value = static_cast<float>(*reinterpret_cast<uint32_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_UINT_64: {
float_value = static_cast<float>(*reinterpret_cast<uint64_t*>(val_bytes.data()));
break;
}
case QNN_DATATYPE_FLOAT_16: {
MLFloat16 fp16_value = *reinterpret_cast<const MLFloat16*>(val_bytes.data());
float_value = fp16_value.ToFloat();
break;
}
case QNN_DATATYPE_FLOAT_32: {
float_value = *reinterpret_cast<const float*>(val_bytes.data());
break;
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Non-quantized min/max input data type not supported.");
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "min/max input data type not supported.");
}

return Status::OK();
}

Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
Status ClipOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
if (node_unit.Inputs().size() > 1) {
const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
@@ -112,7 +168,7 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
std::vector<std::string>& input_names,
bool do_op_validation) const {
if (do_op_validation) {
ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit));
ORT_RETURN_IF_ERROR(ExplicitOpCheck(qnn_model_wrapper, node_unit));
}

return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names);
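For reference, `ExplicitOpCheck` above rejects Clip nodes whose optional min/max inputs are not constant, since their float values are read at graph-build time by `ProcessClipMinMax`. A minimal sketch of that rule, using illustrative stand-in types rather than the ORT/QNN API:

```cpp
// Sketch only: a non-empty min/max input name must resolve to a constant.
#include <cstddef>
#include <string>
#include <vector>

struct ClipInputs {
  std::vector<std::string> names;           // [data, min?, max?]
  bool (*is_constant)(const std::string&);  // e.g. an initializer lookup
};

bool MinMaxAreConstant(const ClipInputs& in) {
  for (std::size_t i = 1; i < in.names.size() && i < 3; ++i) {
    if (!in.names[i].empty() && !in.is_constant(in.names[i])) {
      return false;  // dynamic min/max can't be folded at build time
    }
  }
  return true;
}

int main() {
  ClipInputs in{{"x", "min_const", ""},
                [](const std::string&) { return true; }};  // pretend all constant
  return MinMaxAreConstant(in) ? 0 : 1;
}
```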