11
11
12
12
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
13
13
14
+ # key is op_type, value is the function to compute outputs
15
+ # the schema of function is: inputs are(node, graph), output is a list of constant values.
14
16
_func_map = {}
15
17
16
18
@@ -43,7 +45,7 @@ def _optimize_at_current_graph_level(self, graph):
43
45
44
46
@staticmethod
45
47
def _should_skip (node ):
46
- # only support onnx official op for now, other op such as contrib op not supported for now
48
+ # only support onnx official op for now, op in other domain is not supported for now
47
49
if not utils .is_onnx_domain (node .domain ):
48
50
return True
49
51
@@ -63,10 +65,10 @@ def _fold_node(self, node, graph):
63
65
if self ._all_inputs_are_const (node .inputs ) and not self ._is_graph_output (node , graph ):
64
66
process_func = _func_map .get (node .type , None )
65
67
if process_func :
66
- const_val_after_trans = process_func (node , graph )
67
- self ._replace_node_with_const (node , graph , const_val_after_trans )
68
+ const_outputs = process_func (node , graph )
69
+ self ._replace_node_with_const (node , graph , const_outputs )
68
70
return True
69
- self .log .warning ("need to add function to fold op %s whose op_type is %s" , node .name , node .type )
71
+ self .log .debug ("need to add function to fold op %s whose op_type is %s" , node .name , node .type )
70
72
return False
71
73
72
74
@staticmethod
0 commit comments