Skip to content

Commit e5b1689

Browse files
Formatting
1 parent ece9d37 commit e5b1689

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from typing import List, Optional, Sequence, Tuple, Union
2323

24+
2425
@_ods_cext.register_operation(_Dialect, replace=True)
2526
class ForOp(ForOp):
2627
"""Specialization for the SCF for op class."""
@@ -256,6 +257,7 @@ def for_(
256257
else:
257258
yield iv
258259

260+
259261
@_ods_cext.register_operation(_Dialect, replace=True)
260262
class IndexSwitchOp(IndexSwitchOp):
261263
__doc__ = IndexSwitchOp.__doc__
@@ -308,6 +310,7 @@ def case_blocks(self) -> Sequence[Block]:
308310
def case_block(self, i: int) -> Block:
309311
return self.case_regions[i].blocks[0]
310312

313+
311314
def index_switch(
312315
results_,
313316
arg,

mlir/test/python/dialects/scf.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ def simple_if_else(cond):
363363

364364
@constructAndPrintInModule
365365
def testIndexSwitch():
366-
367366
i32 = T.i32()
367+
368368
@func.FuncOp.from_py_func(T.index(), results=[i32])
369369
def index_switch(index):
370370
c1 = arith.constant(i32, 1)
@@ -384,17 +384,44 @@ def index_switch(index):
384384

385385
for i, block in enumerate(switch_op.case_blocks):
386386
with InsertionPoint(block):
387-
scf.YieldOp([arith.constant(i32, i)])
387+
scf.yield_([arith.constant(i32, i)])
388388

389389
func.return_([switch_op.results[0]])
390390

391391
return index_switch
392392

393393

394+
# CHECK-LABEL: func.func @index_switch(
395+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
396+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
397+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
398+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
399+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
400+
# CHECK: case 0 {
401+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
402+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
403+
# CHECK: }
404+
# CHECK: case 1 {
405+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
406+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
407+
# CHECK: }
408+
# CHECK: case 2 {
409+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
410+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
411+
# CHECK: }
412+
# CHECK: default {
413+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
414+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
415+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
416+
# CHECK: }
417+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
418+
# CHECK: }
419+
420+
394421
@constructAndPrintInModule
395422
def testIndexSwitchWithBodyBuilders():
396-
397423
i32 = T.i32()
424+
398425
@func.FuncOp.from_py_func(T.index(), results=[i32])
399426
def index_switch(index):
400427
c1 = arith.constant(i32, 1)
@@ -406,7 +433,7 @@ def default_body_builder(switch_op):
406433
scf.yield_([c1])
407434

408435
def case_body_builder(switch_op, case_index: int, case_value: int):
409-
scf.YieldOp([arith.constant(i32, case_value)])
436+
scf.yield_([arith.constant(i32, case_value)])
410437

411438
result = scf.index_switch(
412439
results_=[i32],
@@ -419,3 +446,30 @@ def case_body_builder(switch_op, case_index: int, case_value: int):
419446
func.return_([result])
420447

421448
return index_switch
449+
450+
451+
# CHECK-LABEL: func.func @index_switch(
452+
# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
453+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
454+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
455+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
456+
# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
457+
# CHECK: case 0 {
458+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
459+
# CHECK: scf.yield %[[CONSTANT_3]] : i32
460+
# CHECK: }
461+
# CHECK: case 1 {
462+
# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
463+
# CHECK: scf.yield %[[CONSTANT_4]] : i32
464+
# CHECK: }
465+
# CHECK: case 2 {
466+
# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
467+
# CHECK: scf.yield %[[CONSTANT_5]] : i32
468+
# CHECK: }
469+
# CHECK: default {
470+
# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
471+
# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
472+
# CHECK: scf.yield %[[CONSTANT_0]] : i32
473+
# CHECK: }
474+
# CHECK: return %[[INDEX_SWITCH_0]] : i32
475+
# CHECK: }

0 commit comments

Comments
 (0)