Skip to content

Commit 05b8e0c

Browse files
committed
fix bug in fill_op
1 parent 603da0a commit 05b8e0c

File tree

2 files changed

+10
-13
lines changed

2 files changed

+10
-13
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,11 +398,11 @@ def _maxmin_handler(self, trans, node):
398398
for i in all_other_inputs:
399399
target_node = self._g.get_node_by_output(i)
400400
numpy_val = target_node.get_tensor_value(as_list=False)
401-
rank = np.rank(numpy_val)
401+
rank = numpy_val.ndim
402402
if rank == 4:
403403
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
404404
target_node.set_tensor_value(transposed_val)
405-
elif rank == 1: # scalar
405+
elif rank == 1: # scalar
406406
# do nothing
407407
pass
408408
else:

tf2onnx/tfonnx.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,17 +1456,14 @@ def fill_op(ctx, node, name, args):
14561456
# both shape and value in tensorflow are passed as tensor.
14571457
# In onnx the value is an attribute so we need to fetch the value as const which
14581458
# sooner or later will be a problem for tensorflow-onnx.
1459-
shape = ctx.get_shape(node.output[0])
1460-
utils.make_sure(all(i >= 0 for i in shape), "shape attr should not be less than zero")
1461-
value = node.inputs[1].get_tensor_value()
1462-
value_proto = numpy_helper.from_array(node.inputs[1].get_tensor_value(as_list=False))
1463-
dtype = value_proto.data_type
1464-
# onnx spec says value MUST be float.
1465-
node.set_attr("value", float(value))
1466-
node.set_attr("shape", shape)
1467-
node.set_attr("dtype", dtype)
1468-
del node.input[:]
1469-
return node
1459+
# ConstantOfShape in onnxruntime only support int64, so insert cast op
1460+
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.INT64)
1461+
dtype = ctx.get_dtype(node.output[0])
1462+
value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.ONNX_TO_NUMPY_DTYPE[dtype])
1463+
value_proto = numpy_helper.from_array(value)
1464+
node.set_attr("value", value_proto)
1465+
del node.input[1]
1466+
return [node, cast_node]
14701467

14711468

14721469
def reverse_op8(ctx, node, name, args):

0 commit comments

Comments
 (0)