 tf2onnx.rewriter - rewrite tensorflow subgraph to onnx dropout op
 """
 
+import numpy as np
 from tf2onnx import utils
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+from tf2onnx import logging
+
+logger = logging.getLogger(__name__)
 
 
 # pylint: disable=missing-docstring
@@ -18,7 +22,7 @@ def rewrite_dropout(g, ops):
             OpTypePattern('RealDiv', name="input2"),
             OpTypePattern('Floor', inputs=[
                 OpTypePattern('Add', inputs=[
-                    OpTypePattern(None, name="input3"),
+                    OpTypePattern("*", name="input3"),
                     OpTypePattern('RandomUniform|RandomUniformLike'),
                 ])
             ]),
@@ -28,7 +32,7 @@ def rewrite_dropout(g, ops):
             OpTypePattern("Cast", inputs=[
                 OpTypePattern("GreaterEqual", inputs=[
                     OpTypePattern("RandomUniform|RandomUniformLike"),
-                    OpTypePattern(None, name="input3")
+                    OpTypePattern("*", name="input3")
                 ])
             ])
         ]),
@@ -37,7 +41,7 @@ def rewrite_dropout(g, ops):
             OpTypePattern("Cast", inputs=[
                 OpTypePattern("GreaterEqual", inputs=[
                     OpTypePattern("RandomUniform|RandomUniformLike"),
-                    OpTypePattern(None, name="input3")
+                    OpTypePattern("*", name="input3")
                 ])
             ]),
             OpTypePattern("Mul", name="input2"),
@@ -47,30 +51,53 @@ def rewrite_dropout(g, ops):
     matcher = GraphMatcher(pattern, allow_reorder=True)
     match_results = list(matcher.match_ops(ops))
     for match in match_results:
-        inputs2 = match.get_op('input2')
-        if inputs2.inputs[0].type == "RealDiv":
-            data = inputs2.input[1]
-        else:
-            data = inputs2.input[0]
+        input2 = match.get_op('input2')
+        input3 = match.get_op('input3')
         outputs = match.get_op('outputs')
+
+        if not input3.is_scalar():
+            logger.warning("Dropout pattern rooted at %s does not have a "
+                           "constant ratio and cannot be replaced.", outputs.name)
+            continue
+        ratio = input3.get_tensor_value()
+
+        if input2.inputs[0].is_scalar():
+            data = input2.inputs[1]
+            scaling_constant = input2.inputs[0].get_tensor_value()
+        elif input2.inputs[1].is_scalar():
+            data = input2.inputs[0]
+            scaling_constant = input2.inputs[1].get_tensor_value()
+        else:
+            logger.warning("Could not find scaling constant for dropout pattern rooted at %s. "
+                           "The pattern will not be replaced with an ONNX dropout node.", outputs.name)
+            continue
+
+        # The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
+        if not np.allclose([1], [scaling_constant * (1 - ratio)]):
+            logger.warning("Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
+                           "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
+                           scaling_constant, outputs.name, ratio)
+            continue
+
+        nodes_to_remove = [n for n in match.get_nodes() if n.name != input3.name]
+        if not g.is_safe_to_remove_nodes(nodes_to_remove, [outputs.output[0]]):
+            logger.warning("Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
+                           "of some nodes are referenced elsewhere in graph.", outputs.name)
+            continue
+
         op_name = utils.make_name("Dropout")
         out_name = utils.port_name(op_name)
         new_node = g.make_node(
             "Dropout",
-            [data],
+            inputs=[data.output[0]],
             outputs=[out_name],
             name=op_name,
-            attr={"ratio": 1.0},
-            shapes=[g.get_shape(inputs2.input[0])],
-            dtypes=[g.get_dtype(inputs2.input[0])]
+            attr={"ratio": ratio},
+            shapes=[g.get_shape(data.output[0])],
+            dtypes=[g.get_dtype(data.output[0])]
         )
         g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
-        g.safe_remove_nodes(match.get_nodes())
-
-    # remove dropout if its ratio is 1.0
-    for node in g.get_nodes():
-        if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
-            g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
-            g.remove_node(node.name)
+        for n in nodes_to_remove:
+            g.remove_node(n.name)
 
     return ops
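
A minimal sketch (not part of the patch, values hypothetical) of the consistency check the added code enforces: TensorFlow's inverted dropout multiplies the kept activations by 1/(1 - ratio), so the constant found on the RealDiv/Mul branch is accepted only when scaling_constant * (1 - ratio) is approximately 1; otherwise the subgraph is left alone rather than rewritten to an ONNX Dropout node.

import numpy as np

ratio = 0.4                             # dropout ratio read from the "input3" constant (hypothetical)
scaling_constant = 1.0 / (1.0 - ratio)  # the factor TensorFlow bakes into the RealDiv/Mul
# Same test the rewriter performs: passes, so the pattern would be replaced.
assert np.allclose([1], [scaling_constant * (1 - ratio)])
# An unrelated constant fails the test, so the pattern would be left untouched.
assert not np.allclose([1], [2.0 * (1 - ratio)])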