4
4
"""
5
5
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
6
6
"""
7
+ import numpy as np
7
8
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
8
- from tf2onnx import utils
9
+ from tf2onnx import utils , handler
9
10
10
11
11
12
# pylint: disable=missing-docstring
@@ -29,10 +30,10 @@ def rewrite_random_uniform(g, ops):
29
30
# max is on input 0
30
31
tmax = input2 .inputs [0 ].get_tensor_value ()
31
32
tmin = input2 .inputs [1 ].get_tensor_value ()
32
-
33
- new_node = create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output )
33
+ to_delete = list ( set ( match . get_nodes ()))
34
+ new_node = create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output , to_delete )
34
35
g .replace_all_inputs (ops , output .output [0 ], new_node .output [0 ])
35
- for n in set ( match . get_nodes ()) :
36
+ for n in to_delete :
36
37
g .remove_node (n .name )
37
38
38
39
return ops
@@ -59,25 +60,50 @@ def rewrite_random_uniform_fold_const(g, ops):
59
60
tmax_minus_tmin = mul .inputs [1 ].get_tensor_value ()
60
61
tmin = output .inputs [1 ].get_tensor_value ()
61
62
tmax = tmin + tmax_minus_tmin
62
- new_node = create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output )
63
+ to_delete = list (set (match .get_nodes ()))
64
+ new_node = create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output , to_delete )
63
65
g .replace_all_inputs (ops , output .output [0 ], new_node .output [0 ])
64
- for n in set ( match . get_nodes ()) :
66
+ for n in to_delete :
65
67
g .remove_node (n .name )
66
68
67
69
return ops
68
70
69
71
70
- def create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output ):
72
+ def create_onnx_random_uniform_op (g , tmax , tmin , ru_op , output , to_delete ):
71
73
dtype = g .get_dtype (output .output [0 ])
72
74
op_name = utils .make_name ("RandomUniform" )
73
- if ru_op .inputs [0 ].type == "Shape" :
74
- shape_node = ru_op .inputs [0 ]
75
- new_node = g .make_node ("RandomUniformLike" , inputs = [shape_node .input [0 ]], name = op_name ,
76
- attr = {"low" : tmin , "high" : tmax , "dtype" : dtype },
77
- shapes = shape_node .output_shapes , dtypes = [dtype ])
78
- else :
79
- shape = g .get_shape (output .output [0 ])
75
+ shape_node = ru_op .inputs [0 ]
76
+ shape = g .get_shape (output .output [0 ])
77
+ if shape_node .is_const ():
78
+ # if the tensorflow input (aka the shape) is const we can use the RandomUniform op
80
79
new_node = g .make_node ("RandomUniform" , [], name = op_name ,
81
80
attr = {"low" : tmin , "high" : tmax , "dtype" : dtype , "shape" : shape },
82
81
shapes = [shape ], dtypes = [dtype ])
82
+ else :
83
+ if shape_node .type == "Shape" :
84
+ # if shape is dynamic - in tensorflow shape comes as tensor VALUE,
85
+ # in onnx RandomUniformLike finds takes the shape from the tensor itself.
86
+ # In many cases there is a shape op in tensorflow before RandomUniform and
87
+ # to make that work for onnx we just need to remove the shape op.
88
+ new_node = g .make_node ("RandomUniformLike" , inputs = [shape_node .input [0 ]], name = op_name ,
89
+ attr = {"low" : tmin , "high" : tmax , "dtype" : dtype },
90
+ shapes = shape , dtypes = [dtype ])
91
+ else :
92
+ # if the shape is calculated we need to create a tensor so RandomUniformLike
93
+ # can take the shape from there. Pre opset9 this is somewhat hacky because there is
94
+ # no real fill op in onnx. In general this is not going to help performance but the tensors
95
+ # created are expected to be small.
96
+
97
+ # tell the caller to not delete the shape node
98
+ to_delete .remove (shape_node )
99
+ # create a fill op with the shape of the value of the input tensor
100
+ zero = g .make_const (utils .make_name ("zero" ), np .zeros ((), dtype = np .float32 ))
101
+ fill_node = g .make_node ("Fill" , inputs = [shape_node .output [0 ], zero .name ],
102
+ shapes = shape , dtypes = [dtype ])
103
+ func , _ = handler .tf_op .find_effective_op ("Fill" )
104
+ func (g , fill_node )
105
+ # and use RandomUniformLike to create the random tensor
106
+ new_node = g .make_node ("RandomUniformLike" , inputs = [fill_node .output [0 ]], name = op_name ,
107
+ attr = {"low" : tmin , "high" : tmax , "dtype" : dtype },
108
+ shapes = shape , dtypes = [dtype ])
83
109
return new_node
0 commit comments