7
7
8
8
from tf2onnx import utils
9
9
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
10
+ from tf2onnx import logging
11
+
12
+ logger = logging .getLogger (__name__ )
10
13
11
14
12
15
# pylint: disable=missing-docstring
@@ -18,7 +21,7 @@ def rewrite_dropout(g, ops):
18
21
OpTypePattern ('RealDiv' , name = "input2" ),
19
22
OpTypePattern ('Floor' , inputs = [
20
23
OpTypePattern ('Add' , inputs = [
21
- OpTypePattern (None , name = "input3" ),
24
+ OpTypePattern ("*" , name = "input3" ),
22
25
OpTypePattern ('RandomUniform|RandomUniformLike' ),
23
26
])
24
27
]),
@@ -28,7 +31,7 @@ def rewrite_dropout(g, ops):
28
31
OpTypePattern ("Cast" , inputs = [
29
32
OpTypePattern ("GreaterEqual" , inputs = [
30
33
OpTypePattern ("RandomUniform|RandomUniformLike" ),
31
- OpTypePattern (None , name = "input3" )
34
+ OpTypePattern ("*" , name = "input3" )
32
35
])
33
36
])
34
37
]),
@@ -37,7 +40,7 @@ def rewrite_dropout(g, ops):
37
40
OpTypePattern ("Cast" , inputs = [
38
41
OpTypePattern ("GreaterEqual" , inputs = [
39
42
OpTypePattern ("RandomUniform|RandomUniformLike" ),
40
- OpTypePattern (None , name = "input3" )
43
+ OpTypePattern ("*" , name = "input3" )
41
44
])
42
45
]),
43
46
OpTypePattern ("Mul" , name = "input2" ),
@@ -48,10 +51,18 @@ def rewrite_dropout(g, ops):
48
51
match_results = list (matcher .match_ops (ops ))
49
52
for match in match_results :
50
53
inputs2 = match .get_op ('input2' )
54
+ inputs3 = match .get_op ('input3' )
55
+ if inputs3 .type == "Const" :
56
+ ratio = inputs3 .get_tensor_value ()
57
+ else :
58
+ # If the ratio isn't constant, set it to 0
59
+ logger .error ("Dropout node has non-constant ratio. Using ratio=0.0" )
60
+ ratio = 0.0
51
61
if inputs2 .inputs [0 ].type == "RealDiv" :
52
62
data = inputs2 .input [1 ]
53
63
else :
54
64
data = inputs2 .input [0 ]
65
+ # TODO(tomwildenhain): replace dropout node with identity if ratio is 0
55
66
outputs = match .get_op ('outputs' )
56
67
op_name = utils .make_name ("Dropout" )
57
68
out_name = utils .port_name (op_name )
@@ -60,17 +71,11 @@ def rewrite_dropout(g, ops):
60
71
[data ],
61
72
outputs = [out_name ],
62
73
name = op_name ,
63
- attr = {"ratio" : 1.0 },
74
+ attr = {"ratio" : ratio },
64
75
shapes = [g .get_shape (inputs2 .input [0 ])],
65
76
dtypes = [g .get_dtype (inputs2 .input [0 ])]
66
77
)
67
78
g .replace_all_inputs (ops , outputs .output [0 ], new_node .output [0 ])
68
79
g .safe_remove_nodes (match .get_nodes ())
69
80
70
- # remove dropout if its ratio is 1.0
71
- for node in g .get_nodes ():
72
- if node .type == "Dropout" and node .get_attr ("ratio" ).f == 1.0 :
73
- g .replace_all_inputs (g .get_nodes (), node .output [0 ], node .input [0 ])
74
- g .remove_node (node .name )
75
-
76
81
return ops
0 commit comments