@@ -1983,99 +1983,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
1983
1983
return concat_out;
1984
1984
}
1985
1985
1986
- void patterns::QuantDequantOpFuse::operator ()(PDNode *quant_op_input,
1987
- const std::string &op_type,
1988
- const std::string &weight_name,
1989
- int times,
1990
- const std::string &quant_type,
1991
- const std::string &dequant_type) {
1992
- int kNumFields = 5 ;
1993
- const int kQuantizedWeightOffset = 0 ;
1994
- const int kQuantizedOpOffset = 1 ;
1995
- const int kQuantizedOpOutOffset = 2 ;
1996
- const int kDequantOpOffset = 3 ;
1997
- const int kDequantOpOutOffset = 4 ;
1998
- const int kDequantOpWeightScaleOffset = 5 ;
1999
-
2000
- // the quant op always be one.
2001
- auto quant_op_in_scale = pattern->NewNode (GetNodeName (" quant_op_in_scale" ))
1986
+ void patterns::DeleteQuantOpFuse::operator ()(PDNode *input_act_node,
1987
+ const std::string &quant_type) {
1988
+ auto *input_scale_node = pattern->NewNode (GetNodeName (" input_scale_node" ))
2002
1989
->assert_is_op_input (quant_type, " InScale" )
2003
1990
->AsInput ();
2004
- auto quant_op =
2005
- pattern->NewNode (GetNodeName (" quant_op" ))->assert_is_op (quant_type);
2006
-
2007
- PDNode *quant_op_out_scale = nullptr ;
1991
+ auto *quant_node =
1992
+ pattern->NewNode (GetNodeName (" quant_node" ))->assert_is_op (quant_type);
1993
+ auto *output_scale_node = pattern->NewNode (GetNodeName (" output_scale_node" ))
1994
+ ->assert_is_op_output (quant_type, " OutScale" )
1995
+ ->AsOutput ();
1996
+ auto *output_act_node = pattern->NewNode (GetNodeName (" output_act_node" ))
1997
+ ->assert_is_op_output (quant_type, " Out" )
1998
+ ->AsOutput ();
1999
+ quant_node->LinksFrom ({input_scale_node, input_act_node});
2000
+ output_scale_node->LinksFrom ({quant_node});
2001
+ output_act_node->LinksFrom ({quant_node});
2002
+ }
2003
+
2004
+ void patterns::DequantOpFuse::operator ()(PDNode *quantized_op_input,
2005
+ const std::string &quantized_op_type,
2006
+ const std::string &dequant_type,
2007
+ const std::string &weight_name) {
2008
+ auto *quantized_op_weight =
2009
+ pattern->NewNode (GetNodeName (" quantized_op_weight" ))
2010
+ ->assert_is_op_input (quantized_op_type, weight_name)
2011
+ ->AsInput ();
2012
+ auto *quantized_op = pattern->NewNode (GetNodeName (" quantized_op" ))
2013
+ ->assert_is_op (quantized_op_type);
2014
+ auto *quantized_op_out = pattern->NewNode (GetNodeName (" quantized_op_out" ))
2015
+ ->assert_is_op_output (quantized_op_type)
2016
+ ->assert_is_op_input (dequant_type, " X" );
2017
+ auto *dequant_op =
2018
+ pattern->NewNode (GetNodeName (" dequant_op" ))->assert_is_op (dequant_type);
2019
+ auto *dequant_op_out = pattern->NewNode (GetNodeName (" dequant_op_out" ))
2020
+ ->assert_is_op_output (dequant_type, " Out" )
2021
+ ->AsOutput ();
2022
+ PDNode *dequant_channel_scale = nullptr ;
2008
2023
if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
2009
- kNumFields += 1 ;
2010
- quant_op_out_scale = pattern->NewNode (GetNodeName (" quant_op_out_scale" ))
2011
- ->assert_is_op_output (quant_type, " OutScale" )
2012
- ->assert_is_op_nth_input (dequant_type, " Scales" , 1 )
2013
- ->AsIntermediate ();
2014
- } else {
2015
- quant_op_out_scale = pattern->NewNode (GetNodeName (" quant_op_out_scale" ))
2016
- ->assert_is_op_output (quant_type, " OutScale" )
2017
- ->assert_is_op_input (dequant_type, " Scale" )
2018
- ->AsIntermediate ();
2024
+ dequant_channel_scale =
2025
+ pattern->NewNode (GetNodeName (" dequant_channel_scale" ))
2026
+ ->assert_is_op_nth_input (dequant_type, " Scales" , 0 )
2027
+ ->AsInput ();
2019
2028
}
2029
+ quantized_op->LinksFrom ({quantized_op_input, quantized_op_weight});
2030
+ quantized_op_out->LinksFrom ({quantized_op});
2020
2031
2021
- auto quant_op_out = pattern->NewNode (GetNodeName (" quant_op_out" ))
2022
- ->assert_is_op_output (quant_type, " Out" )
2023
- ->assert_is_op_input (op_type)
2024
- ->AsIntermediate ();
2025
-
2026
- // there are 'times' quantized and dequant op
2027
- std::vector<PDNode *> nodes;
2028
- for (int i = 0 ; i < times; i++) {
2029
- nodes.push_back (
2030
- pattern->NewNode (GetNodeName (" quantized_op_weight" ) + std::to_string (i))
2031
- ->assert_is_op_input (op_type, weight_name)
2032
- ->AsInput ());
2033
- nodes.push_back (
2034
- pattern->NewNode (GetNodeName (" quantized_op" ) + std::to_string (i))
2035
- ->assert_is_op (op_type));
2036
-
2037
- nodes.push_back (
2038
- pattern->NewNode (GetNodeName (" quantized_op_out" ) + std::to_string (i))
2039
- ->assert_is_op_output (op_type)
2040
- ->assert_is_op_input (dequant_type, " X" )
2041
- ->AsIntermediate ());
2042
-
2043
- nodes.push_back (
2044
- pattern->NewNode (GetNodeName (" dequant_op" ) + std::to_string (i))
2045
- ->assert_is_op (dequant_type));
2046
-
2047
- nodes.push_back (
2048
- pattern->NewNode (GetNodeName (" dequant_op_out" ) + std::to_string (i))
2049
- ->assert_is_op_output (dequant_type, " Out" )
2050
- ->AsOutput ());
2051
-
2052
- if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
2053
- nodes.push_back (pattern
2054
- ->NewNode (GetNodeName (" dequant_channel_scale" ) +
2055
- std::to_string (i))
2056
- ->assert_is_op_nth_input (dequant_type, " Scales" , 0 )
2057
- ->AsInput ());
2058
- }
2059
- }
2060
-
2061
- quant_op->LinksFrom ({quant_op_input, quant_op_in_scale});
2062
- quant_op_out->LinksFrom ({quant_op});
2063
- for (int i = 0 ; i < times; i++) {
2064
- nodes[i * kNumFields + kQuantizedOpOffset ]->LinksFrom (
2065
- {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset ]});
2066
- nodes[i * kNumFields + kQuantizedOpOutOffset ]->LinksFrom (
2067
- {nodes[i * kNumFields + kQuantizedOpOffset ]});
2068
- if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
2069
- nodes[i * kNumFields + kDequantOpOffset ]->LinksFrom (
2070
- {nodes[i * kNumFields + kQuantizedOpOutOffset ], quant_op_out_scale,
2071
- nodes[i * kNumFields + kDequantOpWeightScaleOffset ]});
2072
- } else {
2073
- nodes[i * kNumFields + kDequantOpOffset ]->LinksFrom (
2074
- {nodes[i * kNumFields + kQuantizedOpOutOffset ], quant_op_out_scale});
2075
- }
2076
- nodes[i * kNumFields + kDequantOpOutOffset ]->LinksFrom (
2077
- {nodes[i * kNumFields + kDequantOpOffset ]});
2032
+ if (dequant_type == " fake_channel_wise_dequantize_max_abs" ) {
2033
+ dequant_op->LinksFrom ({quantized_op_out, dequant_channel_scale});
2034
+ } else {
2035
+ dequant_op->LinksFrom ({quantized_op_out});
2078
2036
}
2037
+ dequant_op_out->LinksFrom ({dequant_op});
2079
2038
}
2080
2039
2081
2040
void patterns::ShuffleChannelPattern::operator ()(PDNode *reshape1_in) {
0 commit comments