@@ -170,12 +170,12 @@ def match(self, node):
170170 s1 = node .weights ['scale' ].data_unquantized
171171 b1 = node .weights ['bias' ].data_unquantized
172172 scale_compatible = (
173- (prev_node .get_attr ('scale_quantizer' ) is None and node .get_attr ('scale_quantizer' ) is None )
173+ (prev_node .get_attr ('scale_quantizer' ) is None or node .get_attr ('scale_quantizer' ) is None )
174174 or (s0 == np .ones_like (s0 )).all ()
175175 or (s1 == np .ones_like (s1 )).all ()
176176 )
177177 bias_compatible = (
178- (prev_node .get_attr ('bias_quantizer' ) is None and node .get_attr ('bias_quantizer' ) is None )
178+ (prev_node .get_attr ('bias_quantizer' ) is None or node .get_attr ('bias_quantizer' ) is None )
179179 or (b0 == np .zeros_like (b0 )).all ()
180180 or (b1 == np .zeros_like (b1 )).all ()
181181 )
@@ -195,26 +195,24 @@ def transform(self, model, node):
195195 # if len(node_map[node.outputs[0]]) > 1:
196196 # return False
197197
198- # only merge if the types are integer or fixed
199- if (
200- not isinstance (prev_node .weights ['scale' ].type .precision , (IntegerPrecisionType , FixedPrecisionType ))
201- or not isinstance (prev_node .weights ['bias' ].type .precision , (IntegerPrecisionType , FixedPrecisionType ))
202- or not isinstance (node .weights ['scale' ].type .precision , (IntegerPrecisionType , FixedPrecisionType ))
203- or not isinstance (node .weights ['bias' ].type .precision , (IntegerPrecisionType , FixedPrecisionType ))
204- ):
205- return False
206-
207198 s0 = prev_node .weights ['scale' ].data_unquantized
208199 b0 = prev_node .weights ['bias' ].data_unquantized
209200 s1 = node .weights ['scale' ].data_unquantized
210201 b1 = node .weights ['bias' ].data_unquantized
211202
212- s_quantizer = (
213- node .get_attr ('scale_quantizer' ) if (s0 == np .ones_like (s0 )).all () else prev_node .get_attr ('scale_quantizer' )
214- )
215- b_quantizer = (
216- node .get_attr ('bias_quantizer' ) if (b0 == np .zeros_like (b0 )).all () else prev_node .get_attr ('bias_quantizer' )
217- )
203+ if (s0 == np .ones_like (s0 )).all ():
204+ s_quantizer = node .get_attr ('scale_quantizer' )
205+ elif (s1 == np .ones_like (s1 )).all ():
206+ s_quantizer = prev_node .get_attr ('scale_quantizer' )
207+ else :
208+ s_quantizer = None
209+
210+ if (b0 == np .ones_like (b0 )).all ():
211+ b_quantizer = node .get_attr ('bias_quantizer' )
212+ elif (b1 == np .ones_like (b1 )).all ():
213+ b_quantizer = prev_node .get_attr ('bias_quantizer' )
214+ else :
215+ b_quantizer = None
218216
219217 node .set_attr ('scale_quantizer' , s_quantizer )
220218 node .set_attr ('bias_quantizer' , b_quantizer )
0 commit comments