Skip to content

Commit eade5fe

Browse files
authored
Add WhereDummyDq Transformer to form Node Unit (microsoft#25576)
### Description

- Add a GraphTransformer `WhereDummyDq` that inserts a dummy DequantizeLinear on a Where node's initializer input, so that the Where node forms a Node Unit when it has one DQ input and one scalar initializer input.
- Add a corresponding unit test for the optimization.

### Motivation and Context

- To avoid additional Dequantize and Quantize nodes being inserted, we would like this pattern to pass `WhereNodeGroupSelector::Check`.
1 parent f91d24c commit eade5fe

File tree

4 files changed

+271
-0
lines changed

4 files changed

+271
-0
lines changed

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
6868
#endif
6969
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
70+
#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
7071
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
7172
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
7273
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
@@ -271,6 +272,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
271272
// It runs unconditionally in InferenceSession::TransformGraph() prior to Level1 optimizers.
272273
// We also put it here with other Level1 optimizers so that it can fix things up after their changes.
273274
transformers.emplace_back(std::make_unique<EnsureUniqueDQForNodeUnit>());
275+
transformers.emplace_back(std::make_unique<WhereDummyDq>());
274276
}
275277

276278
// add __backwardpass attribute to nodes after YieldOp, ROCm-only
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
5+
6+
#include "core/framework/tensorprotoutils.h"
7+
#include "core/common/common.h"
8+
#include "core/util/qmath.h"
9+
#include "core/graph/graph_utils.h"
10+
#include "core/graph/graph_viewer.h"
11+
#include "core/optimizer/initializer.h"
12+
#include "core/optimizer/utils.h"
13+
#include "core/optimizer/qdq_transformer/qdq_util.h"
14+
15+
namespace onnxruntime {
16+
bool WhereDummyDq::SatisfyCondition(const Graph& graph, const Node& node) const {
17+
if (!(node.OpType() == "Where")) {
18+
return false;
19+
}
20+
const auto& where_inputs = node.InputDefs();
21+
const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
22+
const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());
23+
24+
bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName);
25+
bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName);
26+
27+
// WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input
28+
if (is_p1_dq && !parent_node_2) {
29+
return (where_inputs[2]->Shape()->dim_size() == 0);
30+
}
31+
if (!parent_node_1 && is_p2_dq) {
32+
return (where_inputs[1]->Shape()->dim_size() == 0);
33+
}
34+
return false;
35+
}
36+
37+
// Inserts a "dummy" DequantizeLinear in front of the Where node's scalar
// initializer input. The dummy DQ uses data = 1, zero point = 0 and
// scale = original scalar, so it dequantizes to (1 - 0) * scalar == the
// original value while letting the Where form a QDQ node unit.
// Unsupported configurations are skipped with a warning (Status::OK, graph
// unchanged) rather than failing the whole optimization pass.
Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const {
  const auto& where_inputs = node.InputDefs();
  const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name());
  const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name());

  // With SatisfyCondition, we must have one DQ and one initializer.
  const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
  const int const_idx = parent_node_1 ? 2 : 1;

  // The DQ's zero point (input 2) is optional per the ONNX spec, but the dummy
  // DQ mirrors its dtype; bail out when it is absent.
  const auto& dq_inputs = dq_node->InputDefs();
  if (dq_inputs.size() < 3 || !dq_inputs[2]->Exists()) {
    LOGS(logger, WARNING) << "WhereDummyDq requires the existing DQ to have an explicit zero point";
    return Status::OK();
  }

  // Check the lookups succeeded before dereferencing: a non-initializer scale
  // or zero point would otherwise leave these pointers null.
  const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr;
  const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr;
  if (!graph.GetInitializedTensor(dq_inputs[1]->Name(), dq_node_scale_proto) ||
      !graph.GetInitializedTensor(dq_inputs[2]->Name(), dq_node_zp_proto)) {
    LOGS(logger, WARNING) << "WhereDummyDq requires the existing DQ's scale and zero point to be initializers";
    return Status::OK();
  }

  // Dummy data initializer (quantized value 1), dtype matching the DQ's zero point.
  ONNX_NAMESPACE::TensorProto dummy_data_proto;
  dummy_data_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_data"));
  dummy_data_proto.set_data_type(dq_node_zp_proto->data_type());

  // Dummy zero point initializer (value 0), same dtype.
  ONNX_NAMESPACE::TensorProto dummy_zp_proto;
  dummy_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_zp"));
  dummy_zp_proto.set_data_type(dq_node_zp_proto->data_type());

  switch (dummy_zp_proto.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
      int8_t zp = 0;
      int8_t dummy_data = 1;
      dummy_zp_proto.set_raw_data(&zp, sizeof(zp));
      dummy_data_proto.set_raw_data(&dummy_data, sizeof(dummy_data));
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
      uint8_t zp = 0;
      uint8_t dummy_data = 1;
      dummy_zp_proto.set_raw_data(&zp, sizeof(zp));
      dummy_data_proto.set_raw_data(&dummy_data, sizeof(dummy_data));
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
      int16_t zp = 0;
      int16_t dummy_data = 1;
      dummy_zp_proto.set_raw_data(&zp, sizeof(zp));
      dummy_data_proto.set_raw_data(&dummy_data, sizeof(dummy_data));
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
      uint16_t zp = 0;
      uint16_t dummy_data = 1;
      dummy_zp_proto.set_raw_data(&zp, sizeof(zp));
      dummy_data_proto.set_raw_data(&dummy_data, sizeof(dummy_data));
      break;
    }
    default:
      LOGS(logger, WARNING) << "Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16";
      return Status::OK();
  }

  // The dummy DQ's scale carries the original scalar value.
  const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr;
  if (!graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto)) {
    LOGS(logger, WARNING) << "WhereDummyDq requires the Where node's scalar input to be an initializer";
    return Status::OK();
  }
  if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) {
    // WhereDummyDq fills the const value to the dummy DQ's scale
    LOGS(logger, WARNING) << "Currently only support existing DQ's scale with same datatype as scalar";
    return Status::OK();
  }

  Initializer initializer(graph, *const_node_data_proto, graph.ModelPath());

  // Dummy scale initializer.
  ONNX_NAMESPACE::TensorProto dummy_scale_proto;
  dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale"));
  dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type());
  switch (initializer.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      float* where_const_scalar = initializer.data<float>();
      dummy_scale_proto.set_raw_data(where_const_scalar, sizeof(float));
      break;
    }
    default:
      LOGS(logger, WARNING) << "Currently support scalar with FLOAT";
      return Status::OK();
  }

  // Start editing the graph.
  NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_data_proto);
  NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_scale_proto);
  NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_zp_proto);

  // The dummy DQ output has the scalar's element type and (scalar) shape.
  ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto(*const_node_data_proto);
  dummy_dq_type_proto.mutable_tensor_type()->set_elem_type(const_node_data_proto->data_type());
  NodeArg& dummy_dq_arg =
      graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), &dummy_dq_type_proto);
  Node& dummy_dq_node =
      graph.AddNode(
          graph.GenerateNodeArgName(node.Name() + "_dummy_dq"),
          QDQ::DQOpName,
          "DeQuantizeLinear from WhereDummyDq GraphTransformer",
          {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg},
          {&dummy_dq_arg},
          nullptr,
          dq_node->Domain());

  // `where_inputs` is a live view over the node's input defs, so capture the
  // original initializer's name BEFORE rewiring; afterwards the view would
  // already yield the dummy DQ arg and the dead initializer would never be
  // removed.
  const std::string original_const_name = where_inputs[const_idx]->Name();
  node.MutableInputDefs()[const_idx] = &dummy_dq_arg;
  if (graph.GetConsumerNodes(original_const_name).empty()) {
    graph.RemoveInitializedTensor(original_const_name);
  }
  graph.AddEdge(dummy_dq_node.Index(), node.Index(), 0, const_idx);
  modified = true;

  return Status::OK();
}
149+
150+
// Visits every node in topological order, recurses into subgraphs, and inserts
// a dummy DQ on each Where node that matches the target pattern.
Status WhereDummyDq::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
  // The viewer snapshots the topological order so the graph can be mutated
  // while we iterate over the node indices.
  const GraphViewer graph_viewer{graph};

  for (const auto node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
    Node* node = graph.GetNode(node_idx);
    if (node == nullptr) {
      // Node was removed by an earlier transformation.
      continue;
    }

    ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));

    if (SatisfyCondition(graph, *node)) {
      ORT_RETURN_IF_ERROR(InsertDummyDQ(*node, graph, modified, logger));
    }
  }

  return Status::OK();
}
169+
} // namespace onnxruntime
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/optimizer/graph_transformer.h"
7+
8+
namespace onnxruntime {
9+
10+
/**
@Class WhereDummyDq

Graph transformer that inserts a dummy DequantizeLinear on a Where node's
scalar initializer input, so that the Where node forms a QDQ Node Unit when
its other data input is already produced by a DequantizeLinear.
*/
class WhereDummyDq : public GraphTransformer {
 public:
  WhereDummyDq() noexcept : GraphTransformer("WhereDummyDq") {}

 private:
  // Scans the graph (and subgraphs) and applies the transform to every
  // qualifying Where node; sets `modified` when any change is made.
  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

  // True when `node` is a Where with exactly one DequantizeLinear-produced
  // input and one scalar initializer input — the pattern this transformer targets.
  bool SatisfyCondition(const Graph& graph, const Node& node) const;
  // Rewires the Where node's scalar initializer input through a newly inserted
  // dummy DequantizeLinear that reproduces the original scalar value.
  Status InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const;
};
26+
} // namespace onnxruntime

onnxruntime/test/optimizer/qdq_transformer_test.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "core/mlas/inc/mlas.h"
1313
#include "core/optimizer/double_qdq_pairs_remover.h"
1414
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
15+
#include "core/optimizer/qdq_transformer/where_dummy_dq.h"
1516
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
1617
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
1718
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
@@ -3220,6 +3221,79 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
32203221
test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128
32213222
}
32223223

3224+
template <typename ScaleType, typename ZpType>
3225+
void TestWhereWithDqInput(bool is_dq_1,
3226+
bool is_dq_2,
3227+
int expected_num_where,
3228+
int expected_num_dq,
3229+
int expected_num_q,
3230+
bool expected_modified) {
3231+
auto& logger = DefaultLoggingManager().DefaultLogger();
3232+
Model model("WhereDummyDqTester", false, logger);
3233+
Graph& graph = model.MainGraph();
3234+
ModelTestBuilder builder(graph);
3235+
3236+
NodeArg* where_in1 = nullptr;
3237+
NodeArg* where_in2 = nullptr;
3238+
if (is_dq_1) {
3239+
// DQ
3240+
auto* dq_Input = builder.MakeInput<ZpType>({4, 3, 32}, 0.0, 1.0);
3241+
auto* dq_scale = builder.MakeInitializer<ScaleType>({}, 0.0, 1.0);
3242+
auto* dq_zp = builder.MakeInitializer<ZpType>({}, 0.0, 1.0);
3243+
where_in1 = builder.MakeIntermediate();
3244+
builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in1});
3245+
} else {
3246+
where_in1 = builder.MakeInitializer<float>({}, 0.0, 1.0);
3247+
}
3248+
if (is_dq_2) {
3249+
// DQ
3250+
auto* dq_Input = builder.MakeInput<ZpType>({4, 3, 32}, 0.0, 1.0);
3251+
auto* dq_scale = builder.MakeInitializer<ScaleType>({}, 0.0, 1.0);
3252+
auto* dq_zp = builder.MakeInitializer<ZpType>({}, 0.0, 1.0);
3253+
where_in2 = builder.MakeIntermediate();
3254+
builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in2});
3255+
} else {
3256+
where_in2 = builder.MakeInitializer<float>({}, 0.0, 1.0);
3257+
}
3258+
3259+
// Where
3260+
auto* where_cond = builder.MakeInputBool({4, 3, 32});
3261+
auto* where_out = builder.MakeIntermediate();
3262+
builder.AddNode("Where", {where_cond, where_in1, where_in2}, {where_out});
3263+
3264+
// Q
3265+
auto* q_scale = builder.MakeInitializer<float>({}, 0.0, 1.0);
3266+
auto* q_zp = builder.MakeInitializer<uint16_t>({}, 0.0, 1.0);
3267+
auto* q_out = builder.MakeOutput();
3268+
builder.AddNode("QuantizeLinear", {where_out, q_scale, q_zp}, {q_out});
3269+
3270+
builder.SetGraphOutputs();
3271+
ASSERT_STATUS_OK(graph.Resolve());
3272+
3273+
auto where_optimizer = std::make_unique<WhereDummyDq>();
3274+
bool modified = false;
3275+
ASSERT_STATUS_OK(where_optimizer->Apply(graph, modified, logger));
3276+
3277+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
3278+
ASSERT_EQ(op_to_count["Where"], expected_num_where);
3279+
ASSERT_EQ(op_to_count["DequantizeLinear"], expected_num_dq);
3280+
ASSERT_EQ(op_to_count["QuantizeLinear"], expected_num_q);
3281+
ASSERT_EQ(modified, expected_modified);
3282+
3283+
return;
3284+
};
3285+
3286+
TEST(QDQTransformerTests, WhereDummyDqTest) {
  // Args: (is_dq_1, is_dq_2, expected #Where, expected #DQ, expected #Q, expected modified).
  // Both data inputs already DQ: pattern not matched, graph unchanged (2 DQs remain).
  TestWhereWithDqInput<float, uint8_t>(true, true, 1, 2, 1, false);
  // One DQ + one scalar initializer (either side): a dummy DQ is inserted (1 -> 2 DQs).
  TestWhereWithDqInput<float, uint8_t>(true, false, 1, 2, 1, true);
  TestWhereWithDqInput<float, uint8_t>(false, true, 1, 2, 1, true);
  // No DQ input at all: pattern not matched, graph unchanged (0 DQs).
  TestWhereWithDqInput<float, uint8_t>(false, false, 1, 0, 1, false);
  // Same four cases with 16-bit zero-point dtype.
  TestWhereWithDqInput<float, uint16_t>(true, true, 1, 2, 1, false);
  TestWhereWithDqInput<float, uint16_t>(true, false, 1, 2, 1, true);
  TestWhereWithDqInput<float, uint16_t>(false, true, 1, 2, 1, true);
  TestWhereWithDqInput<float, uint16_t>(false, false, 1, 0, 1, false);
}
3296+
32233297
TEST(QDQTransformerTests, Concat) {
32243298
auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes,
32253299
int64_t axis,

0 commit comments

Comments
 (0)