_out_scale_op_list = [
    "conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
    "relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
-   "elementwise_add", "pool2d", "reshape2", "transpose2"
+   "elementwise_add", "pool2d", "reshape2", "transpose2", "concat"
]

# list op real input and output names, to avoid processing input such as AxisTensor.
    "swish": [["X"], ["Out"]],
    "dropout": [["X"], ["Out"]],
    "batch_norm": [["X"], ["Y"]],
-   "sigmoid": [["X"], ["Y"]],
+   "sigmoid": [["X"], ["Out"]],
}
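
The helper _get_op_output_var_names used in the apply hunks below presumably resolves an op node's real output variable names through this mapping. A minimal sketch under that assumption follows; the mapping name _op_real_in_out_name and the op_node.name() / op_node.output(slot) accessors are assumptions for illustration, not taken from this diff.

# Hypothetical sketch: how a helper such as _get_op_output_var_names could
# consult the real in/out name mapping above.
_op_real_in_out_name = {
    "batch_norm": [["X"], ["Y"]],   # excerpt of the mapping shown above
    "sigmoid": [["X"], ["Out"]],
}

def _get_op_output_var_names(op_node):
    var_names = []
    op_name = op_node.name()
    if op_name not in _op_real_in_out_name:
        return var_names
    # The second entry lists the real output slots (e.g. ["Out"] or ["Y"]).
    for slot in _op_real_in_out_name[op_name][1]:
        # Assumed: op_node.output(slot) returns the variable names bound to
        # that slot, so auxiliary outputs/inputs such as AxisTensor are skipped.
        var_names.extend(op_node.output(slot))
    return var_names
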
@@ -1156,20 +1156,27 @@ def apply(self, graph):
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
-       ops = graph.all_op_nodes()
-       for op_node in ops:
-           name = op_node.name()
-           if name in self._teller_set:
-               if len(op_node.output_arg_names()) != 1:
-                   continue
-               in_node = graph._find_node_by_name(
-                   op_node.outputs, op_node.output_arg_names()[0])
+       target_ops = []
+       for op in graph.all_op_nodes():
+           if op.name() in self._teller_set:
+               target_ops.append(op)
+       for op in target_ops:
+           for output_var_name in _get_op_output_var_names(op):
+               in_node = graph._find_node_by_name(op.outputs, output_var_name)
                out_node = graph.create_var_node_from_desc(in_node.var())
                scale_node = graph.create_persistable_node(
                    name=self._scale_name(in_node.name()),
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=in_node.dtype())
+               data_type = 'float64' if in_node.dtype() \
+                   == core.VarDesc.VarType.FP64 else 'float32'
+               _init_var_node(
+                   scale_node,
+                   np.ones(
+                       [1], dtype=data_type),
+                   self._scope,
+                   self._place)
                ins = {'X': in_node}
                outs = {'Out': out_node, 'OutScale': scale_node}
                if not self._is_test:
@@ -1178,8 +1185,6 @@ def apply(self, graph):
                        var_type=core.VarDesc.VarType.LOD_TENSOR,
                        var_dtype=in_node.dtype(),
                        shape=[1])
-                   data_type = 'float64' if in_node.dtype(
-                   ) == core.VarDesc.VarType.FP64 else 'float32'
                    _init_var_node(
                        state_in_node,
                        np.ones(
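
With these two hunks the scale variable is initialized to ones right after it is created, rather than only inside the not self._is_test branch, so inference-only graphs also get a materialized OutScale tensor. For context, a helper like _init_var_node is assumed to do roughly the following; the scope/place tensor API shown here is an assumption about the fluid internals, not part of this diff.

import numpy as np

def _init_var_node(var_node, value, scope, place):
    # Assumed behaviour: create (or fetch) the variable in the executor scope
    # and fill its tensor with the given numpy value on the target device.
    assert isinstance(value, np.ndarray), 'value must be a numpy ndarray.'
    assert scope is not None and place is not None
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)
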
@@ -1257,13 +1262,13 @@ def apply(self, graph):
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
-       ops = graph.all_op_nodes()
-       for op_node in ops:
-           name = op_node.name()
-           if name in self._teller_set:
-               if len(op_node.output_arg_names()) != 1:
-                   continue
-               scale_name = self._scale_name(op_node.output_arg_names()[0])
+       op_nodes = graph.all_op_nodes()
+       for op_node in op_nodes:
+           if op_node.name() in self._teller_set:
+               output_var_name = _get_op_output_var_names(op_node)
+               assert len(output_var_name) == 1, "Only support collecting " \
+                   "output for op that only has an activation output for now."
+               scale_name = self._scale_name(output_var_name[0])
                scale_v = np.array(
                    self._scope.find_var(scale_name).get_tensor())[0]
                op_node.op()._set_attr("out_threshold", float(scale_v))
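
After this pass runs, each matched op carries a scalar out_threshold attribute with its collected output scale. A hypothetical spot check, assuming the has_attr/attr accessors on the op desc:

# Hypothetical verification after running the pass on `graph`.
for op_node in graph.all_op_nodes():
    if op_node.op().has_attr("out_threshold"):
        print(op_node.name(), op_node.op().attr("out_threshold"))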