@@ -50,18 +50,18 @@ def rewrite_dropout(g, ops):
50
50
matcher = GraphMatcher (pattern , allow_reorder = True )
51
51
match_results = list (matcher .match_ops (ops ))
52
52
for match in match_results :
53
- inputs2 = match .get_op ('input2' )
54
- inputs3 = match .get_op ('input3' )
55
- if inputs3 . type == "Const" :
56
- ratio = inputs3 .get_tensor_value ()
53
+ input2 = match .get_op ('input2' )
54
+ input3 = match .get_op ('input3' )
55
+ if input3 . is_const () :
56
+ ratio = input3 .get_tensor_value ()
57
57
else :
58
58
# If the ratio isn't constant, set it to 0
59
- logger .error ("Dropout node has non-constant ratio. Using ratio=0.0" )
59
+ logger .warning ("Dropout node has non-constant ratio. Using ratio=0.0" )
60
60
ratio = 0.0
61
- if inputs2 .inputs [0 ].type == "RealDiv" :
62
- data = inputs2 .input [1 ]
61
+ if input2 .inputs [0 ].type == "RealDiv" :
62
+ data = input2 .input [1 ]
63
63
else :
64
- data = inputs2 .input [0 ]
64
+ data = input2 .input [0 ]
65
65
# TODO(tomwildenhain): replace dropout node with identity if ratio is 0
66
66
outputs = match .get_op ('outputs' )
67
67
op_name = utils .make_name ("Dropout" )
@@ -72,10 +72,19 @@ def rewrite_dropout(g, ops):
72
72
outputs = [out_name ],
73
73
name = op_name ,
74
74
attr = {"ratio" : ratio },
75
- shapes = [g .get_shape (inputs2 .input [0 ])],
76
- dtypes = [g .get_dtype (inputs2 .input [0 ])]
75
+ shapes = [g .get_shape (input2 .input [0 ])],
76
+ dtypes = [g .get_dtype (input2 .input [0 ])]
77
77
)
78
78
g .replace_all_inputs (ops , outputs .output [0 ], new_node .output [0 ])
79
- g .safe_remove_nodes (match .get_nodes ())
79
+ nodes_to_remove = []
80
+ for node in match .get_nodes ():
81
+ if node .name != input3 .name :
82
+ nodes_to_remove .append (node )
83
+ if g .safe_to_remove_nodes (nodes_to_remove ):
84
+ for n in nodes_to_remove :
85
+ g .remove_node (n .name )
86
+ else :
87
+ logger .warning ("Nodes replaced by dropout node cannot be removed because intermediate results are "
88
+ "referenced elsewhere in graph" )
80
89
81
90
return ops
0 commit comments