Skip to content

Commit 02fe9fb

Browse files
[mlir][python] Wrappers for scf.index_switch
The C++ index switch op has utilies for getCaseBlock(int i) and getDefaultBlock(), so these have been added. Optional builder args have been added for the default case and each switch case. The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable.
1 parent 6bad2d1 commit 02fe9fb

File tree

2 files changed

+142
-5
lines changed

2 files changed

+142
-5
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ._scf_ops_gen import *
77
from ._scf_ops_gen import _Dialect
88
from .arith import constant
9+
import builtins
910

1011
try:
1112
from ..ir import *
@@ -19,7 +20,6 @@
1920

2021
from typing import List, Optional, Sequence, Tuple, Union
2122

22-
2323
@_ods_cext.register_operation(_Dialect, replace=True)
2424
class ForOp(ForOp):
2525
"""Specialization for the SCF for op class."""
@@ -254,3 +254,76 @@ def for_(
254254
yield iv, iter_args[0], for_op.results[0]
255255
else:
256256
yield iv
257+
258+
@_ods_cext.register_operation(_Dialect, replace=True)
259+
class IndexSwitchOp(IndexSwitchOp):
260+
__doc__ = IndexSwitchOp.__doc__
261+
262+
def __init__(
263+
self,
264+
results_,
265+
arg,
266+
cases,
267+
case_body_builder=None,
268+
default_body_builder=None,
269+
loc=None,
270+
ip=None,
271+
):
272+
cases = DenseI64ArrayAttr.get(cases)
273+
super().__init__(
274+
results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
275+
)
276+
for region in self.regions:
277+
region.blocks.append()
278+
279+
if default_body_builder is not None:
280+
with InsertionPoint(self.default_block):
281+
default_body_builder(self)
282+
283+
if case_body_builder is not None:
284+
for i, case in enumerate(cases):
285+
with InsertionPoint(self.case_block(i)):
286+
case_body_builder(self, i, self.cases[i])
287+
288+
@builtins.property
289+
def default_region(self) -> Region:
290+
return self.regions[0]
291+
292+
@builtins.property
293+
def default_block(self) -> Block:
294+
return self.default_region.blocks[0]
295+
296+
@builtins.property
297+
def case_regions(self) -> Sequence[Region]:
298+
return [self.regions[1 + i] for i in range(len(self.cases))]
299+
300+
def case_region(self, i: int) -> Region:
301+
return self.case_regions[i]
302+
303+
@builtins.property
304+
def case_blocks(self) -> Sequence[Block]:
305+
return [region.blocks[0] for region in self.case_regions]
306+
307+
def case_block(self, i: int) -> Block:
308+
return self.case_regions[i].blocks[0]
309+
310+
def index_switch(
311+
results_,
312+
arg,
313+
cases,
314+
case_body_builder=None,
315+
default_body_builder=None,
316+
loc=None,
317+
ip=None,
318+
) -> Union[OpResult, OpResultList, IndexSwitchOp]:
319+
op = IndexSwitchOp(
320+
results_=results_,
321+
arg=arg,
322+
cases=cases,
323+
case_body_builder=case_body_builder,
324+
default_body_builder=default_body_builder,
325+
loc=loc,
326+
ip=ip,
327+
)
328+
results = op.results
329+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)

mlir/test/python/dialects/scf.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
from mlir.dialects import arith
5-
from mlir.dialects import func
6-
from mlir.dialects import memref
7-
from mlir.dialects import scf
4+
from mlir.extras import types as T
5+
from mlir.dialects import (
6+
arith,
7+
func,
8+
memref,
9+
scf,
10+
cf,
11+
)
812
from mlir.passmanager import PassManager
913

1014

@@ -355,3 +359,63 @@ def simple_if_else(cond):
355359
# CHECK: scf.yield %[[TWO]], %[[THREE]]
356360
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
357361
# CHECK: return
362+
363+
364+
@constructAndPrintInModule
365+
def testIndexSwitch():
366+
367+
i32 = T.i32()
368+
@func.FuncOp.from_py_func(T.index(), results=[i32])
369+
def index_switch(index):
370+
c1 = arith.constant(i32, 1)
371+
c0 = arith.constant(i32, 0)
372+
value = arith.constant(i32, 5)
373+
switch_op = scf.IndexSwitchOp([i32], index, range(3))
374+
375+
assert switch_op.regions[0] == switch_op.default_region
376+
assert switch_op.regions[1] == switch_op.case_regions[0]
377+
assert switch_op.regions[1] == switch_op.case_region(0)
378+
assert len(switch_op.case_regions) == 3
379+
assert len(switch_op.regions) == 4
380+
381+
with InsertionPoint(switch_op.default_block):
382+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
383+
scf.yield_([c1])
384+
385+
for i, block in enumerate(switch_op.case_blocks):
386+
with InsertionPoint(block):
387+
scf.YieldOp([arith.constant(i32, i)])
388+
389+
func.return_([switch_op.results[0]])
390+
391+
return index_switch
392+
393+
394+
@constructAndPrintInModule
395+
def testIndexSwitchWithBodyBuilders():
396+
397+
i32 = T.i32()
398+
@func.FuncOp.from_py_func(T.index(), results=[i32])
399+
def index_switch(index):
400+
c1 = arith.constant(i32, 1)
401+
c0 = arith.constant(i32, 0)
402+
value = arith.constant(i32, 5)
403+
404+
def default_body_builder(switch_op):
405+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
406+
scf.yield_([c1])
407+
408+
def case_body_builder(switch_op, case_index: int, case_value: int):
409+
scf.YieldOp([arith.constant(i32, case_value)])
410+
411+
result = scf.index_switch(
412+
results_=[i32],
413+
arg=index,
414+
cases=range(3),
415+
case_body_builder=case_body_builder,
416+
default_body_builder=default_body_builder,
417+
)
418+
419+
func.return_([result])
420+
421+
return index_switch

0 commit comments

Comments
 (0)