// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3+ 
4+ #include  " core/optimizer/qdq_transformer/where_dummy_dq.h" 
5+ 
6+ #include  " core/framework/tensorprotoutils.h" 
7+ #include  " core/common/common.h" 
8+ #include  " core/util/qmath.h" 
9+ #include  " core/graph/graph_utils.h" 
10+ #include  " core/graph/graph_viewer.h" 
11+ #include  " core/optimizer/initializer.h" 
12+ #include  " core/optimizer/utils.h" 
13+ #include  " core/optimizer/qdq_transformer/qdq_util.h" 
14+ 
15+ namespace  onnxruntime  {
16+ bool  WhereDummyDq::SatisfyCondition (const  Graph& graph, const  Node& node) const  {
17+   if  (!(node.OpType () == " Where"  )) {
18+     return  false ;
19+   }
20+   const  auto & where_inputs = node.InputDefs ();
21+   const  Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
22+   const  Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
23+ 
24+   bool  is_p1_dq = (parent_node_1 && parent_node_1->OpType () == QDQ::DQOpName);
25+   bool  is_p2_dq = (parent_node_2 && parent_node_2->OpType () == QDQ::DQOpName);
26+ 
27+   //  WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input
28+   if  (is_p1_dq && !parent_node_2) {
29+     return  (where_inputs[2 ]->Shape ()->dim_size () == 0 );
30+   }
31+   if  (!parent_node_1 && is_p2_dq) {
32+     return  (where_inputs[1 ]->Shape ()->dim_size () == 0 );
33+   }
34+   return  false ;
35+ }
36+ 
37+ Status WhereDummyDq::InsertDummyDQ (Node& node, Graph& graph, bool & modified, const  logging::Logger& logger) const  {
38+   const  auto & where_inputs = node.InputDefs ();
39+   const  Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
40+   const  Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
41+ 
42+   //  With SatisfyCondition, we must have one DQ and one initializer
43+   const  Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
44+   int  const_idx = parent_node_1 ? 2  : 1 ;
45+ 
46+   const  ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr ;
47+   graph.GetInitializedTensor (dq_node->InputDefs ()[1 ]->Name (), dq_node_scale_proto);
48+   const  ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr ;
49+   graph.GetInitializedTensor (dq_node->InputDefs ()[2 ]->Name (), dq_node_zp_proto);
50+ 
51+   //  Dummy data initializer.
52+   ONNX_NAMESPACE::TensorProto dummy_data_proto;
53+   dummy_data_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_data"  ));
54+   //  Set data type to dq node's zp dtype
55+   dummy_data_proto.set_data_type (dq_node_zp_proto->data_type ());
56+ 
57+   //  Dummy zero point initializer.
58+   ONNX_NAMESPACE::TensorProto dummy_zp_proto;
59+   dummy_zp_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_zp"  ));
60+   dummy_zp_proto.set_data_type (dq_node_zp_proto->data_type ());
61+ 
62+   switch  (dummy_zp_proto.data_type ()) {
63+     case  ONNX_NAMESPACE::TensorProto_DataType_INT8: {
64+       int8_t  zp = 0 ;
65+       int8_t  dummy_data = 1 ;
66+       dummy_zp_proto.set_raw_data (&zp, 1 );
67+       dummy_data_proto.set_raw_data (&dummy_data, 1 );
68+       break ;
69+     }
70+     case  ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
71+       uint8_t  zp = 0 ;
72+       uint8_t  dummy_data = 1 ;
73+       dummy_zp_proto.set_raw_data (&zp, 1 );
74+       dummy_data_proto.set_raw_data (&dummy_data, 1 );
75+       break ;
76+     }
77+     case  ONNX_NAMESPACE::TensorProto_DataType_INT16: {
78+       int16_t  zp = 0 ;
79+       int16_t  dummy_data = 1 ;
80+       dummy_zp_proto.set_raw_data (&zp, 2 );
81+       dummy_data_proto.set_raw_data (&dummy_data, 2 );
82+       break ;
83+     }
84+     case  ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
85+       uint16_t  zp = 0 ;
86+       uint16_t  dummy_data = 1 ;
87+       dummy_zp_proto.set_raw_data (&zp, 2 );
88+       dummy_data_proto.set_raw_data (&dummy_data, 2 );
89+       break ;
90+     }
91+     default :
92+       LOGS (logger, WARNING) << " Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16"  ;
93+       return  Status::OK ();
94+   }
95+ 
96+   //  Set dummy scale to the original value
97+   const  ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr ;
98+   graph.GetInitializedTensor (where_inputs[const_idx]->Name (), const_node_data_proto);
99+   Initializer initializer (graph, *const_node_data_proto, graph.ModelPath ());
100+   if  (dq_node_scale_proto->data_type () != const_node_data_proto->data_type ()) {
101+     //  WhereDummyDq fills the const value to the dummy DQ's scale
102+     LOGS (logger, WARNING) << " Currently only support existing DQ's scale with same datatype as scalar"  ;
103+     return  Status::OK ();
104+   }
105+ 
106+   //  Dummy scale initializer.
107+   ONNX_NAMESPACE::TensorProto dummy_scale_proto;
108+   dummy_scale_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_scale"  ));
109+   dummy_scale_proto.set_data_type (dq_node_scale_proto->data_type ());
110+   switch  (initializer.data_type ()) {
111+     case  ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
112+       float * where_const_scalar = initializer.data <float >();
113+       dummy_scale_proto.set_raw_data (where_const_scalar, sizeof (float ));
114+       break ;
115+     }
116+     default :
117+       LOGS (logger, WARNING) << " Currently support scalar with FLOAT"  ;
118+       return  Status::OK ();
119+   }
120+ 
121+   //  Start editing the graph
122+   NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_data_proto);
123+   NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_scale_proto);
124+   NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_zp_proto);
125+ 
126+   ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto (*const_node_data_proto);
127+   dummy_dq_type_proto.mutable_tensor_type ()->set_elem_type (const_node_data_proto->data_type ());
128+   NodeArg& dummy_dq_arg =
129+       graph.GetOrCreateNodeArg (graph.GenerateNodeArgName (node.Name () + " _dummy_dq"  ), &dummy_dq_type_proto);
130+   Node& dummy_dq_node =
131+       graph.AddNode (
132+           graph.GenerateNodeArgName (node.Name () + " _dummy_dq"  ),
133+           QDQ::DQOpName,
134+           " DeQuantizeLinear from WhereDummyDq GraphTransformer"  ,
135+           {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg},
136+           {&dummy_dq_arg},
137+           nullptr ,
138+           dq_node->Domain ());
139+ 
140+   node.MutableInputDefs ()[const_idx] = &dummy_dq_arg;
141+   if  (graph.GetConsumerNodes (where_inputs[const_idx]->Name ()).size () == 0 ) {
142+     graph.RemoveInitializedTensor (where_inputs[const_idx]->Name ());
143+   }
144+   graph.AddEdge (dummy_dq_node.Index (), node.Index (), 0 , const_idx);
145+   modified = true ;
146+ 
147+   return  Status::OK ();
148+ }
149+ 
150+ Status WhereDummyDq::ApplyImpl (Graph& graph, bool & modified, int  graph_level, const  logging::Logger& logger) const  {
151+   const  GraphViewer graph_viewer{graph};
152+   const  auto & node_indices = graph_viewer.GetNodesInTopologicalOrder ();
153+   for  (const  auto  node_idx : node_indices) {
154+     auto * node_ptr = graph.GetNode (node_idx);
155+     if  (!node_ptr) {
156+       continue ;
157+     }
158+ 
159+     Node& node = *node_ptr;
160+     ORT_RETURN_IF_ERROR (Recurse (node, modified, graph_level, logger));
161+ 
162+     if  (this ->SatisfyCondition (graph, node)) {
163+       ORT_RETURN_IF_ERROR (WhereDummyDq::InsertDummyDQ (node, graph, modified, logger));
164+     }
165+   }
166+ 
167+   return  Status::OK ();
168+ }
169+ }  //  namespace onnxruntime
0 commit comments