Skip to content

Commit 87baf8f

Browse files
authored
A couple of minor fixes on rewrite rules (#2432)
* The recently introduced scatter-nd elimination optimization requires the `remove_nodes=False` option to be more effective (which was somehow lost in the initial implementation). * Add a missing import statement for future annotations in the `fuse_relus_clips.py` file Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 138cb30 commit 87baf8f

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

onnxscript/rewriter/fuse_relus_clips.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
- Clip(Clip(X)) -> Clip
88
"""
99

10+
from __future__ import annotations
11+
1012
import abc
1113

1214
import numpy as np

onnxscript/rewriter/redundant_scatter_nd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626

2727
class ScatterAllDynamic(orp.RewriteRuleClassBase):
28+
def __init__(self):
29+
super().__init__(remove_nodes=False)
30+
2831
def pattern(self, op, data, axis, transposed_data, updates):
2932
# Construct update-indices spanning an entire axis:
3033
shape = op.Shape(data, start=0)

0 commit comments

Comments
 (0)