Skip to content

Commit b143359

Browse files
authored
fix find_ops (#74)
1 parent 64a8f14 commit b143359

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

mlir/extras/util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
IntegerType,
2424
Location,
2525
MemRefType,
26+
Module,
2627
OpView,
2728
Operation,
2829
RankedTensorType,
@@ -104,7 +105,10 @@ def shlib_prefix():
104105
return shlib_pref
105106

106107

107-
def find_ops(op, pred: Callable[[OpView], bool], single=False):
108+
def find_ops(op, pred: Callable[[OpView, Operation, Module], bool], single=False):
109+
if isinstance(op, (OpView, Module)):
110+
op = op.operation
111+
108112
matching = []
109113

110114
def find(op):
@@ -118,7 +122,7 @@ def find(op):
118122
find(o)
119123

120124
find(op)
121-
if single:
125+
if single and matching:
122126
matching = matching[0]
123127
return matching
124128

0 commit comments

Comments
 (0)