Commit 52b574d (1 parent: d523ace)

refine onnx_to_numpy conversion a bit

File tree: 5 files changed, +19 -21 lines
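
The commit is a mechanical refactor: every direct lookup into the utils.ONNX_TO_NUMPY_DTYPE dict is routed through a new helper, utils.map_onnx_to_numpy_type. A minimal sketch of the before/after pattern (logit_dtype stands in for any ONNX dtype enum value seen at the call sites below):

    # before: raw dict lookup repeated at every call site
    dtype = utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]

    # after: one helper as the single entry point for the mapping,
    # presumably so validation or special cases can later live in one place
    dtype = utils.map_onnx_to_numpy_type(logit_dtype)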

tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
     logit_dtype = ctx.get_dtype(logit_name)
     utils.make_sure(logit_dtype, "Dtype of {} is None".format(logit_name))

-    dtype = utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]
+    dtype = utils.map_onnx_to_numpy_type(logit_dtype)
     eye = np.eye(depth).astype(dtype)
     const_name = utils.make_name("const_eye")
     const_eye = ctx.make_const(name=const_name, np_val=eye)
@@ -82,7 +82,7 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, args):
     make_gathernd(ctx, log_softmax.output[0], indices_with_id.output[0], gathernd_output,
                   gathernd_name, logit_dtype)
     const_name = utils.make_name("const_negative_one")
-    const_negative_one = ctx.make_const(const_name, np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
+    const_negative_one = ctx.make_const(const_name, np.array(-1).astype(utils.map_onnx_to_numpy_type(logit_dtype)))
     mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], gathernd_output])
     shapes = node.output_shapes
     dtypes = node.output_dtypes

tf2onnx/rewriter/rnn_utils.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def get_weights_from_const_node(g, node):

     if temp and temp.type == 'Const':
         val = temp.get_tensor_value(as_list=False)
-        dtype = utils.ONNX_TO_NUMPY_DTYPE[g.get_dtype(temp.output[0])]
+        dtype = utils.map_onnx_to_numpy_type(g.get_dtype(temp.output[0]))
         log.debug("found weights %s", temp.name)
     else:
         log.debug("weight node seems not to be Const, skip, node name is %s", temp.name)

tf2onnx/rewriter/unit_rewriter_base.py

Lines changed: 1 addition & 1 deletion
@@ -420,7 +420,7 @@ def _workaround_fill_ch_init_node(self, initializer_input_id, rnn_props):
             return None

         fill_val = node.inputs[1].get_tensor_value()
-        fill_val_dtype = utils.ONNX_TO_NUMPY_DTYPE[self.g.get_dtype(node.input[1])]
+        fill_val_dtype = utils.map_onnx_to_numpy_type(self.g.get_dtype(node.input[1]))

         # this must be int64, since Concat's input data type must be consistent.
         num_direction_node = self.g.make_const(utils.make_name("Const"), np.array([1], dtype=np.float32))

tf2onnx/tfonnx.py

Lines changed: 10 additions & 10 deletions
@@ -604,7 +604,7 @@ def relu6_op(ctx, node, name, args):
     # since onnx does not have relu6, compose it with multiple ops.
     old_output = node.output[0]
     dtype = ctx.get_dtype(node.input[0])
-    dtype = utils.ONNX_TO_NUMPY_DTYPE[dtype] if dtype else np.float32
+    dtype = utils.map_onnx_to_numpy_type(dtype) if dtype else np.float32
     shape = ctx.get_shape(node.input[0])
     if -1 in shape:
         # if the shape has unknown dims we need to do something like this for opset < 8 (=no broadcast for min/max):
@@ -651,7 +651,7 @@ def relu6_op8(ctx, node, name, args):
     # since onnx does not have relu6, compose it with multiple ops.
     old_output = node.output[0]
     dtype = ctx.get_dtype(node.input[0])
-    dtype = utils.ONNX_TO_NUMPY_DTYPE[dtype] if dtype else np.float32
+    dtype = utils.map_onnx_to_numpy_type(dtype) if dtype else np.float32
     node.type = "Max"
     # const tensor 6
     six_name = utils.make_name(node.name)
@@ -687,7 +687,7 @@ def sign_op(ctx, node, name, args):
     utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
     if node_dtype in [onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
         raise ValueError("dtype " + node_dtype + " is not supported in onnx for now")
-    input_tensor_type = utils.ONNX_TO_NUMPY_DTYPE[node_dtype]
+    input_tensor_type = utils.map_onnx_to_numpy_type(node_dtype)
     zero_name = utils.make_name("{}_zero".format(node.name))
     ctx.make_const(zero_name, np.array(0, dtype=input_tensor_type))
     greater_node = ctx.make_node("Greater", [node.input[0], zero_name])
@@ -1280,7 +1280,7 @@ def fused_batchnorm_op7(ctx, node, name, args):
     scale_shape = ctx.get_shape(node.input[1])
     mean_shape = ctx.get_shape(node.input[3])
     var_shape = ctx.get_shape(node.input[4])
-    val_type = utils.ONNX_TO_NUMPY_DTYPE[ctx.get_dtype(node.input[1])]
+    val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))

     if mean_shape != scale_shape:
         new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
@@ -1387,11 +1387,11 @@ def fill_op(ctx, node, name, args):
     # In onnx the value is an attribute so we need to fetch the value as const which
     # sooner or later will be a problem for tensorflow-onnx.
     # ConstantOfShape in onnxruntime only support int64, so insert cast op
-    input_dtype_is_int64 = utils.ONNX_TO_NUMPY_DTYPE[ctx.get_dtype(node.input[0])] == np.int64
+    input_dtype_is_int64 = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0])) == np.int64
     if not input_dtype_is_int64:
         cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.INT64)
     dtype = ctx.get_dtype(node.output[0])
-    value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.ONNX_TO_NUMPY_DTYPE[dtype])
+    value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.map_onnx_to_numpy_type(dtype))
     value_proto = numpy_helper.from_array(value)
     node.set_attr("value", value_proto)
     del node.input[1]
@@ -1601,7 +1601,7 @@ def zeroslike_op(ctx, node, name, args):
     # when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
     input_dtype = ctx.get_dtype(node.input[0])
     node_name = utils.make_name("zero")
-    const_zero = ctx.make_const(node_name, np.array(0).astype(utils.ONNX_TO_NUMPY_DTYPE[input_dtype]))
+    const_zero = ctx.make_const(node_name, np.array(0).astype(utils.map_onnx_to_numpy_type(input_dtype)))
     shapes = node.output_shapes
     dtypes = node.output_dtypes
     ctx.remove_node(name)
@@ -2051,15 +2051,15 @@ def rewrite_constant_fold(g, ops):
             log.info("folding node type=%s, name=%s" % (op.type, op.name))
             if op.type == "Cast":
                 dst = op.get_attr_int("to")
-                np_type = tf2onnx.utils.ONNX_TO_NUMPY_DTYPE[dst]
+                np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
                 val = np.cast[np_type](*inputs)
             elif op.type == "ConcatV2":
                 axis = inputs[-1]
                 values = inputs[:-1]
                 val = func(tuple(values), axis)
             elif op.type == "ListDiff":
                 out_type = op.get_attr_int("out_idx")
-                np_type = tf2onnx.utils.ONNX_TO_NUMPY_DTYPE[out_type]
+                np_type = tf2onnx.utils.map_onnx_to_numpy_type(out_type)
                 val = func(*inputs)
                 val = val.astype(np_type)
             elif op.type in ["Pack"]:
@@ -2068,7 +2068,7 @@ def rewrite_constant_fold(g, ops):
                 val = func(inputs, axis=axis)
             elif op.type == "Range":
                 dtype = op.get_attr_int("Tidx")
-                np_type = tf2onnx.utils.ONNX_TO_NUMPY_DTYPE[dtype]
+                np_type = tf2onnx.utils.map_onnx_to_numpy_type(dtype)
                 val = func(*inputs, dtype=np_type)
             else:
                 val = func(*inputs)
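
For intuition about the Cast branch in rewrite_constant_fold above: the ONNX "to" attribute is resolved to a numpy dtype, and the constant is cast directly at conversion time. A standalone sketch of that step, using .astype in place of the np.cast indexing form from the diff (the dtype and values are illustrative, not from the source):

    import numpy as np
    from onnx import TensorProto
    from tf2onnx import utils

    dst = TensorProto.INT64                      # stand-in for op.get_attr_int("to")
    np_type = utils.map_onnx_to_numpy_type(dst)  # -> np.int64
    val = np.array([1.5, 2.5]).astype(np_type)   # folded constant: array([1, 2])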

tf2onnx/utils.py

Lines changed: 5 additions & 7 deletions
@@ -138,7 +138,7 @@ def tf_to_onnx_tensor(tensor, name=""):
         dims = [1]
     is_raw, data = get_tf_tensor_data(tensor)
     if not is_raw and len(data) == 1 and np.prod(dims) > 1:
-        batch_data = np.zeros(dims, dtype=ONNX_TO_NUMPY_DTYPE[new_type])
+        batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
         batch_data.fill(data[0])
         onnx_tensor = numpy_helper.from_array(batch_data, name=name)
     else:
@@ -207,6 +207,10 @@ def map_numpy_to_onnx_dtype(np_dtype):
     raise ValueError("unsupported dtype " + np_dtype + " for mapping")


+def map_onnx_to_numpy_type(onnx_type):
+    return ONNX_TO_NUMPY_DTYPE[onnx_type]
+
+
 def node_name(name):
     """Get node name without io#."""
     pos = name.find(":")
@@ -228,12 +232,6 @@ def port_name(name, nr=0):
     return name + ":" + str(nr)


-def make_onnx_identity(node_input, node_output, name=None):
-    if name is None:
-        name = make_name("identity")
-    return helper.make_node("Identity", [node_input], [node_output], name=name)
-
-
 def make_onnx_inputs_outputs(name, elem_type, shape, **kwargs):
     """Wrapper for creating onnx graph inputs or outputs
        name, # type: Text
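
The new helper is a thin wrapper over the existing ONNX_TO_NUMPY_DTYPE table, so it should behave exactly like the dict access it replaces. A quick usage sketch, assuming the table's standard entries (e.g. TensorProto.FLOAT mapping to np.float32):

    import numpy as np
    from onnx import TensorProto
    from tf2onnx import utils

    np_dtype = utils.map_onnx_to_numpy_type(TensorProto.FLOAT)  # np.float32
    zero = np.array(0, dtype=np_dtype)  # same pattern as the sign_op/zeroslike_op call sites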
