Skip to content

Commit ce17599

Browse files
[MLIR][Python] Add wrappers for scf.index_switch (llvm#167458)
The C++ index switch op has utilities for `getCaseBlock(int i)` and `getDefaultBlock()`, so these have been added. Optional body builder args have been added: one for the default case and one for the switch cases.
1 parent b07f8b0 commit ce17599

File tree

3 files changed

+198
-7
lines changed

3 files changed

+198
-7
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._ods_common import (
1313
get_op_result_or_value as _get_op_result_or_value,
1414
get_op_results_or_values as _get_op_results_or_values,
15+
get_op_result_or_op_results as _get_op_result_or_op_results,
1516
_cext as _ods_cext,
1617
)
1718
except ImportError as e:
@@ -254,3 +255,77 @@ def for_(
254255
yield iv, iter_args[0], for_op.results[0]
255256
else:
256257
yield iv
258+
259+
260+
@_ods_cext.register_operation(_Dialect, replace=True)
261+
class IndexSwitchOp(IndexSwitchOp):
262+
__doc__ = IndexSwitchOp.__doc__
263+
264+
def __init__(
265+
self,
266+
results,
267+
arg,
268+
cases,
269+
case_body_builder=None,
270+
default_body_builder=None,
271+
loc=None,
272+
ip=None,
273+
):
274+
cases = DenseI64ArrayAttr.get(cases)
275+
super().__init__(
276+
results, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
277+
)
278+
for region in self.regions:
279+
region.blocks.append()
280+
281+
if default_body_builder is not None:
282+
with InsertionPoint(self.default_block):
283+
default_body_builder(self)
284+
285+
if case_body_builder is not None:
286+
for i, case in enumerate(cases):
287+
with InsertionPoint(self.case_block(i)):
288+
case_body_builder(self, i, self.cases[i])
289+
290+
@property
291+
def default_region(self) -> Region:
292+
return self.regions[0]
293+
294+
@property
295+
def default_block(self) -> Block:
296+
return self.default_region.blocks[0]
297+
298+
@property
299+
def case_regions(self) -> Sequence[Region]:
300+
return self.regions[1:]
301+
302+
def case_region(self, i: int) -> Region:
303+
return self.case_regions[i]
304+
305+
@property
306+
def case_blocks(self) -> Sequence[Block]:
307+
return [region.blocks[0] for region in self.case_regions]
308+
309+
def case_block(self, i: int) -> Block:
310+
return self.case_regions[i].blocks[0]
311+
312+
313+
def index_switch(
314+
results,
315+
arg,
316+
cases,
317+
case_body_builder=None,
318+
default_body_builder=None,
319+
loc=None,
320+
ip=None,
321+
) -> Union[OpResult, OpResultList, IndexSwitchOp]:
322+
op = IndexSwitchOp(
323+
results=results,
324+
arg=arg,
325+
cases=cases,
326+
case_body_builder=case_body_builder,
327+
default_body_builder=default_body_builder,
328+
loc=loc,
329+
ip=ip,
330+
)
331+
return _get_op_result_or_op_results(op)

mlir/test/python/dialects/scf.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
from mlir.dialects import arith
5-
from mlir.dialects import func
6-
from mlir.dialects import memref
7-
from mlir.dialects import scf
4+
from mlir.extras import types as T
5+
from mlir.dialects import (
6+
arith,
7+
func,
8+
memref,
9+
scf,
10+
cf,
11+
)
812
from mlir.passmanager import PassManager
913

1014

@@ -355,3 +359,117 @@ def simple_if_else(cond):
355359
# CHECK: scf.yield %[[TWO]], %[[THREE]]
356360
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
357361
# CHECK: return
362+
363+
364+
@constructAndPrintInModule
365+
def testIndexSwitch():
366+
i32 = T.i32()
367+
368+
@func.FuncOp.from_py_func(T.index(), results=[i32])
369+
def index_switch(index):
370+
c1 = arith.constant(i32, 1)
371+
c0 = arith.constant(i32, 0)
372+
value = arith.constant(i32, 5)
373+
switch_op = scf.IndexSwitchOp([i32], index, range(3))
374+
375+
assert switch_op.regions[0] == switch_op.default_region
376+
assert switch_op.regions[1] == switch_op.case_regions[0]
377+
assert switch_op.regions[1] == switch_op.case_region(0)
378+
assert len(switch_op.case_regions) == 3
379+
assert len(switch_op.regions) == 4
380+
381+
with InsertionPoint(switch_op.default_block):
382+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
383+
scf.yield_([c1])
384+
385+
for i, block in enumerate(switch_op.case_blocks):
386+
with InsertionPoint(block):
387+
scf.yield_([arith.constant(i32, i)])
388+
389+
func.return_([switch_op.results[0]])
390+
391+
return index_switch
392+
393+
394+
# CHECK-LABEL: func.func @index_switch(
395+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
396+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
397+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
398+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
399+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
400+
# CHECK: case 0 {
401+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
402+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
403+
# CHECK: }
404+
# CHECK: case 1 {
405+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
406+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
407+
# CHECK: }
408+
# CHECK: case 2 {
409+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
410+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
411+
# CHECK: }
412+
# CHECK: default {
413+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
414+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
415+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
416+
# CHECK: }
417+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
418+
# CHECK: }
419+
420+
421+
@constructAndPrintInModule
422+
def testIndexSwitchWithBodyBuilders():
423+
i32 = T.i32()
424+
425+
@func.FuncOp.from_py_func(T.index(), results=[i32])
426+
def index_switch(index):
427+
c1 = arith.constant(i32, 1)
428+
c0 = arith.constant(i32, 0)
429+
value = arith.constant(i32, 5)
430+
431+
def default_body_builder(switch_op):
432+
cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
433+
scf.yield_([c1])
434+
435+
def case_body_builder(switch_op, case_index: int, case_value: int):
436+
scf.yield_([arith.constant(i32, case_value)])
437+
438+
result = scf.index_switch(
439+
results=[i32],
440+
arg=index,
441+
cases=range(3),
442+
case_body_builder=case_body_builder,
443+
default_body_builder=default_body_builder,
444+
)
445+
446+
func.return_([result])
447+
448+
return index_switch
449+
450+
451+
# CHECK-LABEL: func.func @index_switch(
452+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
453+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
454+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
455+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
456+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
457+
# CHECK: case 0 {
458+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
459+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
460+
# CHECK: }
461+
# CHECK: case 1 {
462+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
463+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
464+
# CHECK: }
465+
# CHECK: case 2 {
466+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
467+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
468+
# CHECK: }
469+
# CHECK: default {
470+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
471+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
472+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
473+
# CHECK: }
474+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
475+
# CHECK: }

mlir/test/python/ir/operation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,9 +1212,7 @@ def testIndexSwitch():
12121212
@func.FuncOp.from_py_func(T.index())
12131213
def index_switch(index):
12141214
c1 = arith.constant(i32, 1)
1215-
switch_op = scf.IndexSwitchOp(
1216-
results_=[i32], arg=index, cases=range(3), num_caseRegions=3
1217-
)
1215+
switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3))
12181216

12191217
assert len(switch_op.regions) == 4
12201218
assert len(switch_op.regions[2:]) == 2

0 commit comments

Comments
 (0)