Skip to content

Commit 9886e3e

Browse files
[mlir][python] Wrappers for scf.index_switch
The C++ index switch op has utilies for getCaseBlock(int i) and getDefaultBlock(), so these have been added. Optional builder args have been added for the default case and each switch case. The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable.
1 parent eb614cd commit 9886e3e

File tree

2 files changed

+198
-4
lines changed

2 files changed

+198
-4
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from ._scf_ops_gen import *
77
from ._scf_ops_gen import _Dialect
88
from .arith import constant
9+
import builtins
910

1011
try:
1112
from ..ir import *
1213
from ._ods_common import (
1314
get_op_result_or_value as _get_op_result_or_value,
1415
get_op_results_or_values as _get_op_results_or_values,
16+
get_op_result_or_op_results as _get_op_result_or_op_results,
1517
_cext as _ods_cext,
1618
)
1719
except ImportError as e:
@@ -254,3 +256,77 @@ def for_(
254256
yield iv, iter_args[0], for_op.results[0]
255257
else:
256258
yield iv
259+
260+
261+
@_ods_cext.register_operation(_Dialect, replace=True)
262+
class IndexSwitchOp(IndexSwitchOp):
263+
__doc__ = IndexSwitchOp.__doc__
264+
265+
def __init__(
266+
self,
267+
results_,
268+
arg,
269+
cases,
270+
case_body_builder=None,
271+
default_body_builder=None,
272+
loc=None,
273+
ip=None,
274+
):
275+
cases = DenseI64ArrayAttr.get(cases)
276+
super().__init__(
277+
results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
278+
)
279+
for region in self.regions:
280+
region.blocks.append()
281+
282+
if default_body_builder is not None:
283+
with InsertionPoint(self.default_block):
284+
default_body_builder(self)
285+
286+
if case_body_builder is not None:
287+
for i, case in enumerate(cases):
288+
with InsertionPoint(self.case_block(i)):
289+
case_body_builder(self, i, self.cases[i])
290+
291+
@builtins.property
292+
def default_region(self) -> Region:
293+
return self.regions[0]
294+
295+
@builtins.property
296+
def default_block(self) -> Block:
297+
return self.default_region.blocks[0]
298+
299+
@builtins.property
300+
def case_regions(self) -> Sequence[Region]:
301+
return [self.regions[1 + i] for i in range(len(self.cases))]
302+
303+
def case_region(self, i: int) -> Region:
304+
return self.case_regions[i]
305+
306+
@builtins.property
307+
def case_blocks(self) -> Sequence[Block]:
308+
return [region.blocks[0] for region in self.case_regions]
309+
310+
def case_block(self, i: int) -> Block:
311+
return self.case_regions[i].blocks[0]
312+
313+
314+
def index_switch(
315+
results_,
316+
arg,
317+
cases,
318+
case_body_builder=None,
319+
default_body_builder=None,
320+
loc=None,
321+
ip=None,
322+
) -> Union[OpResult, OpResultList, IndexSwitchOp]:
323+
op = IndexSwitchOp(
324+
results_=results_,
325+
arg=arg,
326+
cases=cases,
327+
case_body_builder=case_body_builder,
328+
default_body_builder=default_body_builder,
329+
loc=loc,
330+
ip=ip,
331+
)
332+
return _get_op_result_or_op_results(op)

mlir/test/python/dialects/scf.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
from mlir.dialects import arith
5-
from mlir.dialects import func
6-
from mlir.dialects import memref
7-
from mlir.dialects import scf
4+
from mlir.extras import types as T
5+
from mlir.dialects import (
6+
arith,
7+
func,
8+
memref,
9+
scf,
10+
cf,
11+
)
812
from mlir.passmanager import PassManager
913

1014

@@ -355,3 +359,117 @@ def simple_if_else(cond):
355359
# CHECK: scf.yield %[[TWO]], %[[THREE]]
356360
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
357361
# CHECK: return
362+
363+
364+
@constructAndPrintInModule
365+
def testIndexSwitch():
366+
i32 = T.i32()
367+
368+
@func.FuncOp.from_py_func(T.index(), results=[i32])
369+
def index_switch(index):
370+
c1 = arith.constant(i32, 1)
371+
c0 = arith.constant(i32, 0)
372+
value = arith.constant(i32, 5)
373+
switch_op = scf.IndexSwitchOp([i32], index, range(3))
374+
375+
assert switch_op.regions[0] == switch_op.default_region
376+
assert switch_op.regions[1] == switch_op.case_regions[0]
377+
assert switch_op.regions[1] == switch_op.case_region(0)
378+
assert len(switch_op.case_regions) == 3
379+
assert len(switch_op.regions) == 4
380+
381+
with InsertionPoint(switch_op.default_block):
382+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
383+
scf.yield_([c1])
384+
385+
for i, block in enumerate(switch_op.case_blocks):
386+
with InsertionPoint(block):
387+
scf.yield_([arith.constant(i32, i)])
388+
389+
func.return_([switch_op.results[0]])
390+
391+
return index_switch
392+
393+
394+
# CHECK-LABEL: func.func @index_switch(
395+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
396+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
397+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
398+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
399+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
400+
# CHECK: case 0 {
401+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
402+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
403+
# CHECK: }
404+
# CHECK: case 1 {
405+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
406+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
407+
# CHECK: }
408+
# CHECK: case 2 {
409+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
410+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
411+
# CHECK: }
412+
# CHECK: default {
413+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
414+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
415+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
416+
# CHECK: }
417+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
418+
# CHECK: }
419+
420+
421+
@constructAndPrintInModule
422+
def testIndexSwitchWithBodyBuilders():
423+
i32 = T.i32()
424+
425+
@func.FuncOp.from_py_func(T.index(), results=[i32])
426+
def index_switch(index):
427+
c1 = arith.constant(i32, 1)
428+
c0 = arith.constant(i32, 0)
429+
value = arith.constant(i32, 5)
430+
431+
def default_body_builder(switch_op):
432+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
433+
scf.yield_([c1])
434+
435+
def case_body_builder(switch_op, case_index: int, case_value: int):
436+
scf.yield_([arith.constant(i32, case_value)])
437+
438+
result = scf.index_switch(
439+
results_=[i32],
440+
arg=index,
441+
cases=range(3),
442+
case_body_builder=case_body_builder,
443+
default_body_builder=default_body_builder,
444+
)
445+
446+
func.return_([result])
447+
448+
return index_switch
449+
450+
451+
# CHECK-LABEL: func.func @index_switch(
452+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
453+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
454+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
455+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
456+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
457+
# CHECK: case 0 {
458+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
459+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
460+
# CHECK: }
461+
# CHECK: case 1 {
462+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
463+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
464+
# CHECK: }
465+
# CHECK: case 2 {
466+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
467+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
468+
# CHECK: }
469+
# CHECK: default {
470+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
471+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
472+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
473+
# CHECK: }
474+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
475+
# CHECK: }

0 commit comments

Comments
 (0)