-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][affine] Use value bound inference to determine minimum/maximum trip counts in loop analysis #128113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][affine] Use value bound inference to determine minimum/maximum trip counts in loop analysis #128113
Changes from 2 commits
23b3a7f
c834f4d
e865351
0b30c4e
fa68fe1
82e48ee
e31ff46
e58e115
372926c
80243dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| #include "mlir/Dialect/Affine/IR/AffineValueMap.h" | ||
| #include "mlir/Dialect/Affine/Utils.h" | ||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
|
|
@@ -117,7 +118,8 @@ static void replaceIterArgsAndYieldResults(AffineForOp forOp) { | |
| /// was known to have a single iteration. | ||
| LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { | ||
| std::optional<uint64_t> tripCount = getConstantTripCount(forOp); | ||
| if (!tripCount || *tripCount != 1) | ||
| std::optional<uint64_t> maxTripCount = getMaxConstantTripCount(forOp); | ||
| if (!tripCount || *tripCount != 1 || !maxTripCount || *maxTripCount != 1) | ||
linuxlonelyeagle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return failure(); | ||
linuxlonelyeagle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // TODO: extend this for arbitrary affine bounds. | ||
|
|
@@ -160,7 +162,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { | |
| forOp.getBody()->back().erase(); | ||
| parentBlock->getOperations().splice(Block::iterator(forOp), | ||
| forOp.getBody()->getOperations()); | ||
| forOp.erase(); | ||
| IRRewriter b(forOp.getContext()); | ||
| b.eraseOp(forOp); | ||
linuxlonelyeagle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return success(); | ||
| } | ||
|
|
||
|
|
@@ -884,15 +887,23 @@ void mlir::affine::getTileableBands( | |
| /// Unrolls this loop completely. | ||
| LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) { | ||
| std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); | ||
| if (mayBeConstantTripCount.has_value()) { | ||
| uint64_t tripCount = *mayBeConstantTripCount; | ||
| if (tripCount == 0) | ||
| return success(); | ||
| if (tripCount == 1) | ||
| return promoteIfSingleIteration(forOp); | ||
| return loopUnrollByFactor(forOp, tripCount); | ||
| } | ||
| return failure(); | ||
| std::optional<uint64_t> maxMayBeConstantTripCount = | ||
| getMaxConstantTripCount(forOp); | ||
|
|
||
| if (!mayBeConstantTripCount.has_value() && | ||
| !maxMayBeConstantTripCount.has_value()) | ||
| return failure(); | ||
|
|
||
| uint64_t tripCount = *mayBeConstantTripCount; | ||
| uint64_t maxTripCount = *maxMayBeConstantTripCount; | ||
|
|
||
| // Trip equals 0, this loop cannot unroll. | ||
| if (tripCount <= 0) | ||
| return success(); | ||
|
|
||
| if (tripCount == 1 && maxTripCount == 1) | ||
|
||
| return promoteIfSingleIteration(forOp); | ||
linuxlonelyeagle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return loopUnrollByFactor(forOp, tripCount); | ||
| } | ||
|
|
||
| /// Unrolls this loop by the specified factor or by the trip count (if constant) | ||
|
|
@@ -1013,8 +1024,11 @@ LogicalResult mlir::affine::loopUnrollByFactor( | |
| assert(unrollFactor > 0 && "unroll factor should be positive"); | ||
|
|
||
| std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); | ||
| std::optional<uint64_t> maxMayBeConstantTripCount = | ||
| getMaxConstantTripCount(forOp); | ||
| if (unrollFactor == 1) { | ||
| if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 && | ||
| maxMayBeConstantTripCount && *maxMayBeConstantTripCount == 1 && | ||
| failed(promoteIfSingleIteration(forOp))) | ||
| return failure(); | ||
| return success(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| // UNROLL-BY-4-DAG: [[$MAP5:#map[0-9]*]] = affine_map<(d0)[s0] -> (d0 + s0 + 1)> | ||
| // UNROLL-BY-4-DAG: [[$MAP6:#map[0-9]*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)> | ||
| // UNROLL-BY-4-DAG: [[$MAP11:#map[0-9]*]] = affine_map<(d0) -> (d0)> | ||
| // UNROLL-BY-4-DAG: [[$MAP7:#map[0-9]*]] = affine_map<()[s0] -> (s0 + (((-s0 + 11) ceildiv 2) floordiv 4) * 8)> | ||
|
|
||
| // UNROLL-FULL-LABEL: func @loop_nest_simplest() { | ||
| func.func @loop_nest_simplest() { | ||
|
|
@@ -258,6 +259,71 @@ gpu.module @unroll_full { | |
| } | ||
| } | ||
|
|
||
| // UNROLL-FULL-LABEL: func @thread_partial_execution | ||
| func.func @thread_partial_execution() { | ||
| %0 = arith.constant 0 :index | ||
| %1 = arith.constant 2 : index | ||
| // UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index | ||
| gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1) | ||
|
||
| threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) { | ||
| affine.for %iv = %tx to 3 step 2 iter_args(%arg = %0) -> index { | ||
| %3 = arith.addi %arg, %0 : index | ||
| affine.yield %3 : index | ||
| } | ||
| // UNROLL-FULL: affine.for %{{.*}} = %{{.*}} to 3 step 2 iter_args(%[[ARG:.*]] = %[[C0]]) -> (index) { | ||
| // UNROLL-FULL-NEXT: %[[SUM:.*]] = arith.addi %[[ARG]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: affine.yield %[[SUM]] : index | ||
| // UNROLL-FULL-NEXT: } | ||
| gpu.terminator | ||
| } | ||
| return | ||
| } | ||
|
|
||
| // UNROLL-FULL-LABEL: func @unroll_all_thread | ||
| func.func @unroll_all_thread() { | ||
| %0 = arith.constant 0 :index | ||
| %1 = arith.constant 2 : index | ||
| // UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index | ||
| gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1) | ||
| threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) { | ||
| %threadid = gpu.thread_id x | ||
| %4 = affine.for %iv = %threadid to 6 step 2 iter_args(%arg = %0) -> index { | ||
| %3 = arith.addi %arg, %0 : index | ||
| affine.yield %3 : index | ||
| } | ||
| // UNROLL-FULL: %[[SUM_0:.*]] = arith.addi %[[C0]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[C0]] : index | ||
| gpu.terminator | ||
| } | ||
| return | ||
| } | ||
|
|
||
| // UNROLL-FULL-LABEL: func.func @partial_unroll_factor_4 | ||
| func.func @partial_unroll_factor_4() { | ||
| %0 = arith.constant 0 :index | ||
| %1 = arith.constant 2 : index | ||
| // UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index | ||
| gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1) | ||
| threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) { | ||
| %threadid = gpu.thread_id x | ||
| affine.for %iv = %threadid to 9 step 2 iter_args(%arg = %0) -> index { | ||
| %3 = arith.addi %arg, %0 : index | ||
| affine.yield %3 : index | ||
| } | ||
| gpu.terminator | ||
| } | ||
| // UNROLL-FULL: %[[ID:.*]] = gpu.thread_id x | ||
| // UNROLL-FULL-NEXT: affine.for %{{.*}} = %[[ID]] to 9 step 8 iter_args(%[[ARG:.*]] = %[[C0]]) -> (index) { | ||
| // UNROLL-FULL-NEXT: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[C0]] : index | ||
| // UNROLL-FULL-NEXT: affine.yield %[[SUM_3]] : index | ||
| // UNROLL-FULL-NEXT: } | ||
| return | ||
| } | ||
|
|
||
| // SHORT-LABEL: func @loop_nest_outer_unroll() { | ||
| func.func @loop_nest_outer_unroll() { | ||
| // SHORT: affine.for %arg0 = 0 to 4 { | ||
|
|
@@ -701,6 +767,32 @@ func.func @unroll_with_iter_args_and_promotion(%arg0 : f32, %arg1 : f32) -> f32 | |
| return %sum : f32 | ||
| } | ||
|
|
||
| // UNROLL-BY-4-LABEL: func @gpu_launch_unroll_by_factor_4 | ||
| func.func @gpu_launch_unroll_by_factor_4() { | ||
| %0 = arith.constant 0 :index | ||
| %1 = arith.constant 2 : index | ||
linuxlonelyeagle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // UNROLL-BY-4: %[[C0:.*]] = arith.constant 0 : index | ||
| gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1) | ||
| threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) { | ||
| %threadid = gpu.thread_id x | ||
| affine.for %iv = %threadid to 11 step 2 iter_args(%arg = %0) -> index { | ||
| %3 = arith.addi %arg, %0 : index | ||
| affine.yield %3 : index | ||
| } | ||
| gpu.terminator | ||
| } | ||
| // UNROLL-BY-4: %[[ID:.*]] = gpu.thread_id x | ||
| // UNROLL-BY-4-NEXT: %[[SUM_0:.*]] = arith.addi %[[C0]], %[[C0]] : index | ||
| // UNROLL-BY-4-NEXT: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[C0]] : index | ||
| // UNROLL-BY-4-NEXT: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[C0]] : index | ||
| // UNROLL-BY-4-NEXT: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[C0]] : index | ||
| // UNROLL-BY-4-NEXT: affine.for %[[VAL_20:.*]] = [[$MAP7]](){{\[}}%[[ID]]] to 11 step 2 iter_args(%[[ARG:.*]] = %[[SUM_3]]) -> (index) { | ||
| // UNROLL-BY-4-NEXT: %[[SUM_4:.*]] = arith.addi %[[ARG]], %[[C0]] : index | ||
| // UNROLL-BY-4-NEXT: affine.yield %[[SUM_4]] : index | ||
| // UNROLL-BY-4-NEXT: } | ||
| return | ||
| } | ||
|
|
||
| // UNROLL-FULL: func @unroll_zero_trip_count_case | ||
| func.func @unroll_zero_trip_count_case() { | ||
| // CHECK-NEXT: affine.for %{{.*}} = 0 to 0 | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.