|
1 | 1 | # RUN: %PYTHON %s | FileCheck %s |
2 | 2 |
|
3 | 3 | from mlir.ir import * |
4 | | -from mlir.dialects import arith |
5 | | -from mlir.dialects import func |
6 | | -from mlir.dialects import memref |
7 | | -from mlir.dialects import scf |
| 4 | +from mlir.extras import types as T |
| 5 | +from mlir.dialects import ( |
| 6 | + arith, |
| 7 | + func, |
| 8 | + memref, |
| 9 | + scf, |
| 10 | + cf, |
| 11 | +) |
8 | 12 | from mlir.passmanager import PassManager |
9 | 13 |
|
10 | 14 |
|
@@ -355,3 +359,117 @@ def simple_if_else(cond): |
355 | 359 | # CHECK: scf.yield %[[TWO]], %[[THREE]] |
356 | 360 | # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1 |
357 | 361 | # CHECK: return |
| 362 | + |
| 363 | + |
| 364 | +@constructAndPrintInModule |
| 365 | +def testIndexSwitch(): |
| 366 | + i32 = T.i32() |
| 367 | + |
| 368 | + @func.FuncOp.from_py_func(T.index(), results=[i32]) |
| 369 | + def index_switch(index): |
| 370 | + c1 = arith.constant(i32, 1) |
| 371 | + c0 = arith.constant(i32, 0) |
| 372 | + value = arith.constant(i32, 5) |
| 373 | + switch_op = scf.IndexSwitchOp([i32], index, range(3)) |
| 374 | + |
| 375 | + assert switch_op.regions[0] == switch_op.default_region |
| 376 | + assert switch_op.regions[1] == switch_op.case_regions[0] |
| 377 | + assert switch_op.regions[1] == switch_op.case_region(0) |
| 378 | + assert len(switch_op.case_regions) == 3 |
| 379 | + assert len(switch_op.regions) == 4 |
| 380 | + |
| 381 | + with InsertionPoint(switch_op.default_block): |
| 382 | + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") |
| 383 | + scf.yield_([c1]) |
| 384 | + |
| 385 | + for i, block in enumerate(switch_op.case_blocks): |
| 386 | + with InsertionPoint(block): |
| 387 | + scf.yield_([arith.constant(i32, i)]) |
| 388 | + |
| 389 | + func.return_([switch_op.results[0]]) |
| 390 | + |
| 391 | + return index_switch |
| 392 | + |
| 393 | + |
| 394 | +# CHECK-LABEL: func.func @index_switch( |
| 395 | +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { |
| 396 | +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 |
| 397 | +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 |
| 398 | +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 |
| 399 | +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 |
| 400 | +# CHECK: case 0 { |
| 401 | +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 |
| 402 | +# CHECK: scf.yield %[[CONSTANT_3]] : i32 |
| 403 | +# CHECK: } |
| 404 | +# CHECK: case 1 { |
| 405 | +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 |
| 406 | +# CHECK: scf.yield %[[CONSTANT_4]] : i32 |
| 407 | +# CHECK: } |
| 408 | +# CHECK: case 2 { |
| 409 | +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 |
| 410 | +# CHECK: scf.yield %[[CONSTANT_5]] : i32 |
| 411 | +# CHECK: } |
| 412 | +# CHECK: default { |
| 413 | +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false |
| 414 | +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" |
| 415 | +# CHECK: scf.yield %[[CONSTANT_0]] : i32 |
| 416 | +# CHECK: } |
| 417 | +# CHECK: return %[[INDEX_SWITCH_0]] : i32 |
| 418 | +# CHECK: } |
| 419 | + |
| 420 | + |
| 421 | +@constructAndPrintInModule |
| 422 | +def testIndexSwitchWithBodyBuilders(): |
| 423 | + i32 = T.i32() |
| 424 | + |
| 425 | + @func.FuncOp.from_py_func(T.index(), results=[i32]) |
| 426 | + def index_switch(index): |
| 427 | + c1 = arith.constant(i32, 1) |
| 428 | + c0 = arith.constant(i32, 0) |
| 429 | + value = arith.constant(i32, 5) |
| 430 | + |
| 431 | + def default_body_builder(switch_op): |
| 432 | + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") |
| 433 | + scf.yield_([c1]) |
| 434 | + |
| 435 | + def case_body_builder(switch_op, case_index: int, case_value: int): |
| 436 | + scf.yield_([arith.constant(i32, case_value)]) |
| 437 | + |
| 438 | + result = scf.index_switch( |
| 439 | + results=[i32], |
| 440 | + arg=index, |
| 441 | + cases=range(3), |
| 442 | + case_body_builder=case_body_builder, |
| 443 | + default_body_builder=default_body_builder, |
| 444 | + ) |
| 445 | + |
| 446 | + func.return_([result]) |
| 447 | + |
| 448 | + return index_switch |
| 449 | + |
| 450 | + |
| 451 | +# CHECK-LABEL: func.func @index_switch( |
| 452 | +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { |
| 453 | +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 |
| 454 | +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 |
| 455 | +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 |
| 456 | +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 |
| 457 | +# CHECK: case 0 { |
| 458 | +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 |
| 459 | +# CHECK: scf.yield %[[CONSTANT_3]] : i32 |
| 460 | +# CHECK: } |
| 461 | +# CHECK: case 1 { |
| 462 | +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 |
| 463 | +# CHECK: scf.yield %[[CONSTANT_4]] : i32 |
| 464 | +# CHECK: } |
| 465 | +# CHECK: case 2 { |
| 466 | +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 |
| 467 | +# CHECK: scf.yield %[[CONSTANT_5]] : i32 |
| 468 | +# CHECK: } |
| 469 | +# CHECK: default { |
| 470 | +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false |
| 471 | +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" |
| 472 | +# CHECK: scf.yield %[[CONSTANT_0]] : i32 |
| 473 | +# CHECK: } |
| 474 | +# CHECK: return %[[INDEX_SWITCH_0]] : i32 |
| 475 | +# CHECK: } |
0 commit comments