Skip to content

Commit f0ecc8f

Browse files
committed
[MLIR][SCF] Expose scf util functions in header file
This patch exposes the `delinearizeInductionVariable` and `getProductOfIntsOrIndexes` helpers in the header file for the SCF utils, as these are useful for downstream users. Additionally, `getProductOfIntsOrIndexes` will now constant-fold the generated `arith::MulIOp`.
1 parent a04ab7b commit f0ecc8f

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,22 @@ void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
148148
Value normalizedIv, OpFoldResult origLb,
149149
OpFoldResult origStep);
150150

151+
/// For each original loop, the value of the induction variable can be obtained
152+
/// by dividing the induction variable of the linearized loop by the total
153+
/// number of iterations of the loops nested in it modulo the number of
154+
/// iterations in this loop (remove the values related to the outer loops):
155+
/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
156+
/// Compute these iteratively from the innermost loop by creating a "running
157+
/// quotient" of division by the range.
158+
/// Returns the delinearized induction variables and the preserved users.
159+
std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
160+
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
161+
Value linearizedIv, ArrayRef<Value> ubs);
162+
163+
/// Helper function to multiply a sequence of values.
164+
Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
165+
ArrayRef<Value> values);
166+
151167
/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
152168
/// parametric tile sizes that the outer loops have a fixed number of iterations
153169
/// as defined in `sizes`.

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Arith/Utils/Utils.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/IR/BuiltinOps.h"
2122
#include "mlir/IR/IRMapping.h"
2223
#include "mlir/IR/OpDefinition.h"
@@ -807,7 +808,7 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
807808

808809
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
809810
ArrayRef<OpFoldResult> values) {
810-
assert(!values.empty() && "unexecpted empty array");
811+
assert(!values.empty() && "unexpected empty array");
811812
AffineExpr s0, s1;
812813
bindSymbols(rewriter.getContext(), s0, s1);
813814
AffineExpr mul = s0 * s1;
@@ -819,9 +820,8 @@ static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
819820
return products;
820821
}
821822

822-
/// Helper function to multiply a sequence of values.
823-
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
824-
ArrayRef<Value> values) {
823+
Value mlir::getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
824+
ArrayRef<Value> values) {
825825
assert(!values.empty() && "unexpected empty list");
826826
if (getType(values.front()).isIndex()) {
827827
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
@@ -835,7 +835,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
835835
continue;
836836
if (productOf)
837837
productOf =
838-
rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
838+
rewriter.createOrFold<arith::MulIOp>(loc, productOf.value(), v);
839839
else
840840
productOf = v;
841841
}
@@ -848,17 +848,9 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
848848
return productOf.value();
849849
}
850850

851-
/// For each original loop, the value of the
852-
/// induction variable can be obtained by dividing the induction variable of
853-
/// the linearized loop by the total number of iterations of the loops nested
854-
/// in it modulo the number of iterations in this loop (remove the values
855-
/// related to the outer loops):
856-
/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
857-
/// Compute these iteratively from the innermost loop by creating a "running
858-
/// quotient" of division by the range.
859-
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
860-
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
861-
Value linearizedIv, ArrayRef<Value> ubs) {
851+
std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
852+
mlir::delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
853+
Value linearizedIv, ArrayRef<Value> ubs) {
862854

863855
if (linearizedIv.getType().isIndex()) {
864856
Operation *delinearizedOp =

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -500,13 +500,11 @@ func.func @coalesce_i32_loops() {
500500
%1 = arith.constant 128 : i32
501501
%2 = arith.constant 2 : i32
502502
%3 = arith.constant 64 : i32
503-
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : i32
504503
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
505504
// CHECK: %[[ONE:.*]] = arith.constant 1 : i32
506-
// CHECK: %[[VAL_7:.*]] = arith.constant 32 : i32
507505
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
508506
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i32
509-
// CHECK: %[[UB:.*]] = arith.muli %[[VAL_4]], %[[VAL_7]] : i32
507+
// CHECK: %[[UB:.*]] = arith.constant 2048 : i32
510508
// CHECK: scf.for %[[VAL_11:.*]] = %[[ZERO]] to %[[UB]] step %[[ONE]] : i32 {
511509
scf.for %i = %0 to %1 step %2 : i32 {
512510
scf.for %j = %0 to %3 step %2 : i32 {

0 commit comments

Comments
 (0)