Skip to content

Commit 0473788

Browse files
committed
fix issue with randomop
1 parent c421da8 commit 0473788

File tree

4 files changed

+35
-18
lines changed

4 files changed

+35
-18
lines changed

benchmarks/profile_conversion_time.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def spy_convert(graph_def, model):
3333
tf.import_graph_def(graph_def=graph_def, name='')
3434

3535
def spy_convert_in():
36-
return tfonnx.process_tf_graph(tf_graph=graph,
37-
input_names=[model.input.name],
38-
output_names=[model.output.name])
36+
return tfonnx.process_tf_graph(
37+
tf_graph=graph, input_names=[model.input.name],
38+
output_names=[model.output.name])
3939

4040
spy_convert_in()
4141

@@ -55,19 +55,21 @@ def convert(graph_def, model):
5555
spy_convert(graph_def, model)
5656

5757

58-
def profile(profiler="pyinstrument", name="MobileNet", show_all=False,
58+
def profile(profiler="none", name="MobileNet", show_all=False,
5959
module='tf.keras'):
6060
"""
6161
Profiles the conversion of a model.
62-
63-
:param profiler: one among spy, pyinstrument, cProfile
62+
63+
:param profiler: one among none, spy, pyinstrument, cProfile
6464
:param name: model to profile, MobileNet, EfficientNetB2
6565
:param show_all: use by pyinstrument to show all functions
6666
"""
6767
print("create(%r, %r, %r)" % (profiler, name, module))
6868
graph_def, model = create(name, module)
6969
print("profile(%r, %r, %r)" % (profiler, name, module))
70-
if profiler == "spy":
70+
if profiler == 'none':
71+
convert(graph_def, model)
72+
elif profiler == "spy":
7173
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
7274
convert(graph_def, model)
7375
elif profiler == "pyinstrument":

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def get_tensor_value(self, as_list=True):
282282
when as_list=True, return 1, type is <class 'int'>.
283283
"""
284284
if not self.is_const():
285-
raise ValueError("get tensor value: {} must be Const".format(self.name))
285+
raise ValueError("get tensor value: '{}' must be Const".format(self.name))
286286

287287
t = self.get_attr("value")
288288
if t:

tf2onnx/onnx_opset/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def version_1(cls, ctx, node, **kwargs):
3838
# const we make it an attribute.
3939
seed = node.get_attr("seed")
4040
node.set_attr("seed", float(seed.f))
41-
if len(node.input) > 0:
41+
if len(node.input) > 0 and node.inputs[0].is_const():
4242
shape = node.inputs[0].get_tensor_value()
4343
ctx.remove_input(node, node.input[0], 0)
4444
node.set_attr("shape", shape)

tf2onnx/tfonnx.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def tensorflow_onnx_mapping(g, ops_mapping):
257257
func(g, node, **kwargs)
258258
node.skip_conversion = True
259259
except Exception as ex:
260-
logger.error("Failed to convert node %s\n%s", node.name, node.summary, exc_info=1)
260+
logger.error("Failed to convert node %r (fct=%r)\n%r",
261+
node.name, func, node.summary, exc_info=1)
261262
exceptions.append(ex)
262263

263264
return mapped_op, unmapped_op, exceptions
@@ -450,14 +451,28 @@ def compat_handler(ctx, node, **kwargs):
450451

451452
# pre-processing graph rewrites
452453
# bi-directional re-writer should be placed after single directional re-writer
453-
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
454-
rewrite_random_uniform, rewrite_random_uniform_fold_const,
455-
rewrite_random_normal, rewrite_dropout, rewrite_eye,
456-
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
457-
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
458-
rewrite_single_direction_gru, rewrite_bi_direction_gru,
459-
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
460-
rewrite_biasadd_with_conv2d, rewrite_gemm
454+
rewriters = [# single directional
455+
rewrite_constant_fold,
456+
rewrite_quantize_and_dequantize,
457+
rewrite_transpose,
458+
rewrite_flatten,
459+
rewrite_random_uniform,
460+
rewrite_random_uniform_fold_const,
461+
rewrite_random_normal,
462+
rewrite_dropout,
463+
rewrite_eye,
464+
rewrite_leakyrelu,
465+
rewrite_thresholded_relu,
466+
rewrite_conv2d_with_pad,
467+
rewrite_single_direction_lstm,
468+
# bi-directional
469+
rewrite_bi_direction_lstm,
470+
rewrite_single_direction_gru,
471+
rewrite_bi_direction_gru,
472+
rewrite_custom_rnn_cell,
473+
rewrite_generic_loop, rewrite_cond,
474+
rewrite_biasadd_with_conv2d,
475+
rewrite_gemm,
461476
]
462477

463478
if custom_rewriter is not None:

0 commit comments

Comments
 (0)