Skip to content

Commit 6718483

Browse files
authored
Test do-while with and condition (#1277)
* Do not create undef init values for variables selected from init vars * Test do-while with and condition
1 parent 5b5f547 commit 6718483

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

src/enzyme_ad/jax/Passes/CanonicalizeFor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,15 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
14561456
if (blockArg.getOwner() == &condOp->getParentRegion()->front()) {
14571457
newArg = loop.getOperand(blockArg.getArgNumber());
14581458
}
1459+
} else if (auto selectOp = arg.getDefiningOp<arith::SelectOp>()) {
1460+
auto trueBlockArg = dyn_cast<BlockArgument>(selectOp.getTrueValue());
1461+
auto falseBlockArg = dyn_cast<BlockArgument>(selectOp.getFalseValue());
1462+
1463+
if (trueBlockArg && !falseBlockArg) {
1464+
newArg = loop.getOperand(trueBlockArg.getArgNumber());
1465+
} else if (!trueBlockArg && falseBlockArg) {
1466+
newArg = loop.getOperand(falseBlockArg.getArgNumber());
1467+
}
14591468
}
14601469
forArgs.push_back(newArg);
14611470
}

test/lit_tests/canonicalizefor/doWhile.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,61 @@ module @test_multiple_args_dynamic {
491491
// CHECK-NEXT: return %[[FOR]]#0, %[[FOR]]#1 : index, index
492492
// CHECK-NEXT: }
493493
// CHECK-NEXT: }
494+
495+
//----
496+
497+
// Loop condition is an and expression
498+
module @test_and_condition {
499+
func.func @do_while(%ub : i32) -> (i32, f32) {
500+
%cst = arith.constant 0.000000e+00 : f32
501+
%cst1 = arith.constant 1.000000e+00 : f32
502+
%c0_i32 = arith.constant 0 : i32
503+
%c1_i32 = arith.constant 1 : i32
504+
%true = arith.constant true
505+
%2:3 = scf.while (%arg10 = %c0_i32, %arg12 = %cst, %ac = %true) : (i32, f32, i1) -> (i32, f32, i1) {
506+
%3 = arith.cmpi ult, %arg10, %ub : i32
507+
%a = arith.andi %3, %ac : i1
508+
%p = arith.addi %arg10, %c1_i32 : i32
509+
%c = "test.something"() : () -> (i1)
510+
%4 = arith.addf %arg12, %cst1 : f32
511+
scf.condition(%a) %p, %4, %c : i32, f32, i1
512+
} do {
513+
^bb0(%arg10: i32, %arg12: f32, %ac: i1):
514+
scf.yield %arg10, %arg12, %ac : i32, f32, i1
515+
}
516+
return %2#0, %2#1 : i32, f32
517+
}
518+
}
519+
520+
// CHECK-LABEL: module @test_and_condition {
521+
// CHECK: func.func @do_while(%[[UB:.+]]: i32) -> (i32, f32) {
522+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
523+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
524+
// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
525+
// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
526+
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0.000000e+00 : f32
527+
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
528+
// CHECK-DAG: %[[UNDEF_I32:.+]] = ub.poison : i32
529+
// CHECK-DAG: %[[UNDEF_F32:.+]] = ub.poison : f32
530+
// CHECK-DAG: %[[UNDEF_I1:.+]] = ub.poison : i1
531+
// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[UB]], %[[C0]] : i32
532+
// CHECK-NEXT: %[[ADJ_UB:.+]] = arith.addi %[[MAX]], %[[C1]] : i32
533+
// CHECK-NEXT: %[[FOR:.+]]:7 = scf.for %[[IV:.+]] = %[[C0]] to %[[ADJ_UB]] step %[[C1]] iter_args(%[[ARG0:.+]] = %[[C0]], %[[ARG1:.+]] = %[[CST0]], %[[ARG2:.+]] = %[[TRUE]], %[[ARG3:.+]] = %[[UNDEF_I32]], %[[ARG4:.+]] = %[[UNDEF_F32]], %[[ARG5:.+]] = %[[UNDEF_I1]], %[[ARG6:.+]] = %[[TRUE]]) -> (i32, f32, i1, i32, f32, i1, i1) : i32 {
534+
// CHECK-NEXT: %[[IF1:.+]]:4 = scf.if %[[ARG6]] -> (i32, f32, i1, i1) {
535+
// CHECK-NEXT: %[[ADDI:.+]] = arith.addi %[[ARG0]], %[[C1]] : i32
536+
// CHECK-NEXT: %[[VAL:.+]] = "test.something"() : () -> i1
537+
// CHECK-NEXT: %[[ADDF:.+]] = arith.addf %[[ARG1]], %[[CST1]] : f32
538+
// CHECK-NEXT: scf.yield %[[ADDI]], %[[ADDF]], %[[VAL]], %[[ARG2]] : i32, f32, i1, i1
539+
// CHECK-NEXT: } else {
540+
// CHECK-NEXT: scf.yield %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[FALSE]] : i32, f32, i1, i1
541+
// CHECK-NEXT: }
542+
// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi slt, %[[IV]], %[[UB]] : i32
543+
// CHECK-NEXT: %[[COND:.+]] = arith.andi %[[CMP]], %[[IF1]]#3 : i1
544+
// CHECK-NEXT: %[[IF2:.+]]:3 = scf.if %[[COND]] -> (i32, f32, i1) {
545+
// CHECK-NEXT: scf.yield %[[IF1]]#0, %[[IF1]]#1, %[[IF1]]#2 : i32, f32, i1
546+
// CHECK-NEXT: } else {
547+
// CHECK-NEXT: scf.yield %[[UNDEF_I32]], %[[UNDEF_F32]], %[[UNDEF_I1]] : i32, f32, i1
548+
// CHECK-NEXT: }
549+
// CHECK-NEXT: scf.yield %[[IF2]]#0, %[[IF2]]#1, %[[IF2]]#2, %[[IF1]]#0, %[[IF1]]#1, %[[IF1]]#2, %[[IF1]]#3 : i32, f32, i1, i32, f32, i1, i1
550+
// CHECK-NEXT: }
551+
// CHECK-NEXT: return %[[FOR]]#3, %[[FOR]]#4 : i32, f32

test/lit_tests/canonicalizefor/while_to_for.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ func.func @foo(%arg0: memref<1x104x194xf64, 1>, %arg1: memref<35xf64, 1>, %arg2:
3737
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<1x104x194xf64, 1>,
3838
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<35xf64, 1>,
3939
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<34xf64, 1>) {
40-
// CHECK-DAG: %[[undef_f64:.+]] = ub.poison : f64
4140
// CHECK-DAG: %[[c21:.*]] = arith.constant 21 : i64
4241
// CHECK-DAG: %[[c6:.*]] = arith.constant 6 : index
4342
// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
@@ -46,7 +45,7 @@ func.func @foo(%arg0: memref<1x104x194xf64, 1>, %arg1: memref<35xf64, 1>, %arg2:
4645
// CHECK: %[[VAL_9:.*]] = affine.load %[[VAL_0]][0, %[[VAL_7]] + 7, %[[VAL_8]] + 7] : memref<1x104x194xf64, 1>
4746
// CHECK: %[[VAL_10:.*]] = affine.load %[[VAL_1]][7] : memref<35xf64, 1>
4847
// CHECK: affine.store %[[VAL_10]], %[[VAL_0]][0, %[[VAL_7]] + 7, %[[VAL_8]] + 7] : memref<1x104x194xf64, 1>
49-
// CHECK: %[[VAL_11:.*]]:2 = scf.for %[[VAL_12:.*]] = %[[c1]] to %[[c21]] step %[[c1]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]], %[[arg7:.+]] = %[[undef_f64]]) -> (f64, f64) : i64 {
48+
// CHECK: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[c1]] to %[[c21]] step %[[c1]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (f64) : i64 {
5049
// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
5150
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[c7]] : index
5251
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_15]]] : memref<35xf64, 1>
@@ -56,9 +55,9 @@ func.func @foo(%arg0: memref<1x104x194xf64, 1>, %arg1: memref<35xf64, 1>, %arg2:
5655
// CHECK: %[[VAL_20:.*]] = arith.cmpf ole, %[[VAL_19]], %[[VAL_9]] : f64
5756
// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_20]], %[[VAL_16]], %[[VAL_13]] : f64
5857
// CHECK: affine.store %[[VAL_21]], %[[VAL_0]][0, %[[VAL_7]] + 7, %[[VAL_8]] + 7] : memref<1x104x194xf64, 1>
59-
// CHECK: scf.yield %[[VAL_21]], %[[VAL_21]] : f64, f64
58+
// CHECK: scf.yield %[[VAL_21]] : f64
6059
// CHECK: }
61-
// CHECK: "test.use"(%[[VAL_11]]#1, %[[c21]]) : (f64, i64) -> ()
60+
// CHECK: "test.use"(%[[VAL_11]], %[[c21]]) : (f64, i64) -> ()
6261
// CHECK: }
6362
// CHECK: return
6463
// CHECK: }

0 commit comments

Comments
 (0)