diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index e3717aa9d940e..c7588b433dee1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -240,10 +240,10 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp, loc, forOp.getUpperBound().getType(), splitBound); // Peel the first iteration. - IRMapping map; - map.map(forOp.getUpperBound(), splitBound); - firstIteration = cast(b.clone(*forOp.getOperation(), map)); - + firstIteration = cast(b.clone(*forOp.getOperation())); + b.modifyOpInPlace(firstIteration, [&]() { + firstIteration.getUpperBoundMutable().assign(splitBound); + }); // Update main loop with new lower bound. b.modifyOpInPlace(forOp, [&]() { forOp.getInitArgsMutable().assign(firstIteration->getResults()); diff --git a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir index fe3b3e686a3e4..1737c6b6f6c5f 100644 --- a/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling-front.mlir @@ -8,7 +8,7 @@ // CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index // CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] // CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]] +// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C17]], %[[IV]])[%[[C4]]] // CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32 // CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32 // CHECK: scf.yield %[[INIT]] @@ -35,7 +35,49 @@ func.func @fully_static_bounds() -> i32 { } return %r : i32 } +// ----- +// CHECK-LABEL: func.func @static_two_iterations_ub_used_in_loop( +// CHECK-SAME: %[[IFM1:.*]]: memref<1xi32>) -> i32 { +// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[C0_IDX:.*]] = arith.constant 0 : index +// CHECK: %[[C4_IDX:.*]] = arith.constant 4 : index +// CHECK: %[[C7_IDX:.*]] = arith.constant 7 : index +// CHECK: %[[FOR1:.*]] = scf.for %[[IV1:.*]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C4_IDX]] iter_args(%[[ARG1:.*]] = %[[C0_I32]]) -> (i32) { +// CHECK: %[[AFFINE_MIN1:.*]] = affine.min #map(%[[C7_IDX]], %[[IV1]]){{\[}}%[[C4_IDX]]] +// CHECK: %[[CAST1:.*]] = arith.index_cast %[[AFFINE_MIN1]] : index to i32 +// CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG1]], %[[CAST1]] : i32 +// CHECK: %[[LOAD1:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32> +// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI1]], %[[LOAD1]] : i32 +// CHECK: scf.yield %[[ADDI2]] : i32 +// CHECK: } +// CHECK: %[[FOR2:.*]] = scf.for %[[IV2:.*]] = %[[C4_IDX]] to %[[C7_IDX]] step %[[C4_IDX]] iter_args(%[[ARG2:.*]] = %[[RESULT1:.*]]) -> (i32) { +// CHECK: %[[AFFINE_MIN2:.*]] = affine.min #map(%[[C7_IDX]], %[[IV2]]){{\[}}%[[C4_IDX]]] +// CHECK: %[[CAST2:.*]] = arith.index_cast %[[AFFINE_MIN2]] : index to i32 +// CHECK: %[[ADDI3:.*]] = arith.addi %[[ARG2]], %[[CAST2]] : i32 +// CHECK: %[[LOAD2:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32> +// CHECK: %[[ADDI4:.*]] = arith.addi %[[ADDI3]], %[[LOAD2]] : i32 +// CHECK: scf.yield %[[ADDI4]] : i32 +// CHECK: } +// CHECK: return %[[RESULT2:.*]] : i32 +// CHECK: } + +#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)> +func.func @static_two_iterations_ub_used_in_loop(%arg0: memref<1xi32>) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %lb = arith.constant 0 : index + %step = arith.constant 4 : index + %ub = arith.constant 7 : index + %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 { + %s = affine.min #map(%ub, %iv)[%step] + %casted = arith.index_cast %s : index to i32 + %0 = arith.addi %arg, %casted : i32 + %1 = memref.load %arg0[%ub] : memref<1xi32> + %result = arith.addi %0, %1 : i32 + scf.yield %result : i32 + } + return %r : i32 +} // ----- // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1)[s0] -> (4, d0 - d1)> @@ -44,7 +86,7 @@ func.func @fully_static_bounds() -> i32 { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C4]] { -// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C4]], %[[IV]])[%[[C4]]] +// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[UB]], %[[IV]])[%[[C4]]] // CHECK: %[[LOAD:.*]] = memref.load %[[MEMREF]][] // CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] // CHECK: %[[ADD:.*]] = arith.addi %[[LOAD]], %[[CAST]] : i32 @@ -83,7 +125,7 @@ func.func @no_loop_results(%ub : index, %d : memref) { // CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[STEP]]] // CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]] // CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MIN:.*]] = affine.min #[[MAP1]](%[[NEW_UB]], %[[IV]])[%[[STEP]]] +// CHECK: %[[MIN:.*]] = affine.min #[[MAP1]](%[[UB]], %[[IV]])[%[[STEP]]] // CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32 // CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32 // CHECK: scf.yield %[[ADD]] @@ -119,7 +161,7 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 { // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK: %[[FIRST:.*]] = scf.for %[[IV:.*]] = %[[C2]] to %[[C6]] // CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C6]], %[[IV]])[%[[C4]]] +// CHECK: %[[MIN:.*]] = affine.min #[[MAP]](%[[C8]], %[[IV]])[%[[C4]]] // CHECK: %[[CAST:.*]] = arith.index_cast %[[MIN]] : index to i32 // CHECK: %[[INIT:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32 // CHECK: scf.yield %[[INIT]] diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir index be58548b1fcfe..084576625f32c 100644 --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -35,6 +35,35 @@ func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 { // ----- +// CHECK-LABEL: func.func @static_two_iterations_ub_used_in_loop( +// CHECK-SAME: %[[IFM1:.*]]: memref<1xi32>) -> i32 { +// CHECK: %[[C7_I32:.*]] = arith.constant 7 : i32 +// CHECK: %[[C7_IDX:.*]] = arith.constant 7 : index +// CHECK: %[[LOAD1:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32> +// CHECK: %[[ADDI1:.*]] = arith.addi %[[LOAD1]], %[[C7_I32]] : i32 +// CHECK: %[[LOAD2:.*]] = memref.load %[[IFM1]]{{\[}}%[[C7_IDX]]] : memref<1xi32> +// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI1]], %[[LOAD2]] : i32 +// CHECK: return %[[ADDI2]] : i32 +// CHECK: } + +#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)> +func.func @static_two_iterations_ub_used_in_loop(%arg0: memref<1xi32>) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %lb = arith.constant 0 : index + %step = arith.constant 4 : index + %ub = arith.constant 7 : index + %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0_i32) -> i32 { + %s = affine.min #map(%ub, %iv)[%step] + %casted = arith.index_cast %s : index to i32 + %0 = arith.addi %arg, %casted : i32 + %1 = memref.load %arg0[%ub] : memref<1xi32> + %result = arith.addi %0, %1 : i32 + scf.yield %result : i32 + } + return %r : i32 +} +// ----- + // CHECK: func @fully_static_bounds( // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32