Skip to content

Commit 5e48449

Browse files
Fix RandomUniform/RandomNormal rewriters for non-const inputs (#1710)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 2a6504f commit 5e48449

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def rewrite_random_normal(g, ops):
1616
pattern1 = \
1717
OpTypePattern('Add', name='output', inputs=[
1818
OpTypePattern('Mul', name='input2', inputs=[
19-
OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
20-
]), "*"
19+
OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "Const|ConstV2"
20+
]), "Const|ConstV2"
2121
])
2222

2323
pattern2 = \

tf2onnx/rewriter/random_uniform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def rewrite_random_uniform(g, ops):
1818
OpTypePattern('Add', name='output', inputs=[
1919
OpTypePattern('Mul', inputs=[
2020
OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
21-
OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
21+
OpTypePattern('Sub', name='input2', inputs=["Const|ConstV2", "Const|ConstV2"]),
2222
]), None
2323
])
2424

@@ -45,9 +45,9 @@ def rewrite_random_uniform_fold_const(g, ops):
4545
OpTypePattern('Add', name='output', inputs=[
4646
OpTypePattern('Mul', name='mul', inputs=[
4747
OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
48-
None,
48+
"Const|ConstV2",
4949
]),
50-
None,
50+
"Const|ConstV2",
5151
])
5252

5353
matcher = GraphMatcher(pattern)

0 commit comments

Comments
 (0)