|
5 | 5 | tf2onnx.rewriter - rewrite tensorflow subgraph to onnx dropout op
|
6 | 6 | """
|
7 | 7 |
|
| 8 | +import numpy as np |
8 | 9 | from tf2onnx import utils
|
9 | 10 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
|
10 | 11 | from tf2onnx import logging
|
@@ -52,39 +53,51 @@ def rewrite_dropout(g, ops):
|
52 | 53 | for match in match_results:
|
53 | 54 | input2 = match.get_op('input2')
|
54 | 55 | input3 = match.get_op('input3')
|
55 |
| - if input3.is_const(): |
56 |
| - ratio = input3.get_tensor_value() |
57 |
| - else: |
58 |
| - # If the ratio isn't constant, set it to 0 |
59 |
| - logger.warning("Dropout node has non-constant ratio. Using ratio=0.0") |
60 |
| - ratio = 0.0 |
61 |
| - if input2.inputs[0].type == "RealDiv": |
62 |
| - data = input2.input[1] |
63 |
| - else: |
64 |
| - data = input2.input[0] |
65 |
| - # TODO(tomwildenhain): replace dropout node with identity if ratio is 0 |
66 | 56 | outputs = match.get_op('outputs')
|
| 57 | + |
| 58 | + if not input3.is_scalar(): |
| 59 | + logger.warning("Dropout pattern rooted at %s does not have a " |
| 60 | + "constant ratio and cannot be replaced.", outputs.name) |
| 61 | + continue |
| 62 | + ratio = input3.get_tensor_value() |
| 63 | + |
| 64 | + if input2.inputs[0].is_scalar(): |
| 65 | + data = input2.inputs[1] |
| 66 | + scaling_constant = input2.inputs[0].get_tensor_value() |
| 67 | + elif input2.inputs[1].is_scalar(): |
| 68 | + data = input2.inputs[0] |
| 69 | + scaling_constant = input2.inputs[1].get_tensor_value() |
| 70 | + else: |
| 71 | + logger.warning("Could not find scaling constant for dropout pattern rooted at %s. " |
| 72 | + "The pattern will not be replaced with an ONNX dropout node.", outputs.name) |
| 73 | + continue |
| 74 | + |
| 75 | + # The scaling constant should be 1/(1 - ratio); otherwise this isn't truly a dropout node |
| 76 | + if not np.allclose([1], [scaling_constant * (1 - ratio)]): |
| 77 | + logger.warning("Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout " |
| 78 | + "ratio %f. The pattern will not be replaced with an ONNX dropout node.", |
| 79 | + scaling_constant, outputs.name, ratio) |
| 80 | + continue |
| 81 | + |
| 82 | + nodes_to_remove = [n for n in match.get_nodes() if n.name != input3.name] |
| 83 | + if not g.is_safe_to_remove_nodes(nodes_to_remove, [outputs.output[0]]): |
| 84 | + logger.warning("Nodes in dropout pattern rooted at %s cannot be removed because intermediate results " |
| 85 | + "of some nodes are referenced elsewhere in graph.", outputs.name) |
| 86 | + continue |
| 87 | + |
67 | 88 | op_name = utils.make_name("Dropout")
|
68 | 89 | out_name = utils.port_name(op_name)
|
69 | 90 | new_node = g.make_node(
|
70 | 91 | "Dropout",
|
71 |
| - [data], |
| 92 | + inputs=[data.output[0]], |
72 | 93 | outputs=[out_name],
|
73 | 94 | name=op_name,
|
74 | 95 | attr={"ratio": ratio},
|
75 |
| - shapes=[g.get_shape(input2.input[0])], |
76 |
| - dtypes=[g.get_dtype(input2.input[0])] |
| 96 | + shapes=[g.get_shape(data.output[0])], |
| 97 | + dtypes=[g.get_dtype(data.output[0])] |
77 | 98 | )
|
78 | 99 | g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
|
79 |
| - nodes_to_remove = [] |
80 |
| - for node in match.get_nodes(): |
81 |
| - if node.name != input3.name: |
82 |
| - nodes_to_remove.append(node) |
83 |
| - if g.safe_to_remove_nodes(nodes_to_remove): |
84 |
| - for n in nodes_to_remove: |
85 |
| - g.remove_node(n.name) |
86 |
| - else: |
87 |
| - logger.warning("Nodes replaced by dropout node cannot be removed because intermediate results are " |
88 |
| - "referenced elsewhere in graph") |
| 100 | + for n in nodes_to_remove: |
| 101 | + g.remove_node(n.name) |
89 | 102 |
|
90 | 103 | return ops
|
0 commit comments