Skip to content

Commit 5f918ab

Browse files
inonbefatcat-z
andauthored
A fix for seed attribute in the keras random normal generator (#2126)
* A fix for seed field in the tf.keras random normal generator. The seed field was not passing to the converted onnx model due to that it exists in seed2 attribute instead of seed and its type is Integer and not float. --------- Signed-off-by: inonbe <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 535f74c commit 5f918ab

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tf2onnx/onnx_opset/generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def version_1(cls, ctx, node, **kwargs):
6060
# in the rewriter does not trigger. grappler will send the random uniform
6161
# with shape as input so we need to pickup the input here and if the shape is
6262
# const we make it an attribute.
63-
seed = node.get_attr("seed")
64-
node.set_attr("seed", float(seed.f))
63+
seed = node.get_attr("seed2")
64+
node.set_attr("seed", float(seed.i))
6565
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9", node.type)
6666
shape = node.inputs[0].get_tensor_value()
6767
ctx.remove_input(node, node.input[0], 0)
@@ -88,8 +88,8 @@ def version_9(cls, ctx, node, **kwargs):
8888
if node.inputs[0].is_const():
8989
cls.version_1(ctx, node, **kwargs)
9090
else:
91-
seed = node.get_attr("seed")
92-
node.set_attr("seed", float(seed.f))
91+
seed = node.get_attr("seed2")
92+
node.set_attr("seed", float(seed.i))
9393
cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': onnx_pb.TensorProto.INT64})
9494
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
9595
inputs = node.input.copy()

0 commit comments

Comments
 (0)