Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,142 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
return success(*map != initialMap);
}

/// Walk `e` looking for sums that contain every expression in
/// `exprsToRemove` as a summand.
///
/// A sum is viewed as the chain `e1 + e2 + ... + eK` obtained by repeatedly
/// splitting an `Add` node into its right operand and its left operand. When
/// all of `exprsToRemove` occur among the `e_i`, an entry mapping the sum to
/// `newVal + (the summands that were not removed)` is inserted into
/// `replacementsMap`. Each expression in `exprsToRemove` cancels at most one
/// summand; further copies of it are kept.
///
/// Summands that are kept, and the operands of non-`Add` binary expressions,
/// are searched recursively. Non-binary expressions are ignored. If
/// `replacementsMap` gains no entries, no matching sum exists in `e`.
static void shortenAddChainsContainingAll(
    AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
    AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
  auto root = dyn_cast<AffineBinaryOpExpr>(e);
  if (!root)
    return;

  // Anything other than a sum cannot be shortened itself; only its operands
  // may contain a matching chain.
  if (root.getKind() != AffineExprKind::Add) {
    for (AffineExpr operand : {root.getLHS(), root.getRHS()})
      shortenAddChainsContainingAll(operand, exprsToRemove, newVal,
                                    replacementsMap);
    return;
  }

  // Flatten the chain into its summands, rightmost first, by following the
  // left operands for as long as they are themselves sums.
  SmallVector<AffineExpr> summands;
  for (AffineExpr spine = e; spine;) {
    auto add = dyn_cast<AffineBinaryOpExpr>(spine);
    if (!add || add.getKind() != AffineExprKind::Add) {
      summands.push_back(spine);
      break;
    }
    summands.push_back(add.getRHS());
    spine = add.getLHS();
  }

  // Tick off each expected expression the first time it is seen. Every other
  // summand is kept and searched for nested chains.
  llvm::SmallDenseSet<AffineExpr, 4> stillMissing(exprsToRemove);
  SmallVector<AffineExpr> kept;
  for (AffineExpr summand : summands) {
    if (stillMissing.erase(summand))
      continue;
    kept.push_back(summand);
    shortenAddChainsContainingAll(summand, exprsToRemove, newVal,
                                  replacementsMap);
  }
  if (!stillMissing.empty())
    return;

  // `kept` was gathered right-to-left, so append in reverse to retain the
  // original left-to-right order of the surviving summands.
  AffineExpr shortened = newVal;
  for (AffineExpr summand : llvm::reverse(kept))
    shortened = shortened + summand;
  replacementsMap.insert({e, shortened});
}

/// If `map` contains the sum `x_n + x_{n-1} * C_{n-1} + ... + x_1 * C_1`
/// (summands in any order), where the `x_i` are exactly the results of
/// `delinOp` and the `C_i` are the strides implied by its static basis, replace
/// that sum with the linear index that `delinOp` takes as input.
///
/// `resultToReplace` is the operand that triggered this rewrite. Unless it is
/// the last (rightmost) result of `delinOp`, the function fails; this is
/// intended to avoid scanning the same expression once per result.
///
/// The function also fails when `delinOp` has a dynamic basis, when some
/// result of `delinOp` is not present in `dims` or `syms`, or when no matching
/// sum exists in `map`. Handling only constant bases is not a fundamental
/// limitation.
///
/// On success the linear index is appended to `dims` if every result was bound
/// as a dimension, and to `syms` otherwise; `*map` is rewritten; and the
/// `dims`/`syms` entries of results that the new map no longer uses are set to
/// null.
///
/// This is a utility function for `replaceDimOrSym` below.
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
    AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
    SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
  if (!delinOp.getDynamicBasis().empty() ||
      resultToReplace != delinOp.getMultiIndex().back())
    return failure();

  MLIRContext *ctx = delinOp.getContext();
  auto *delinOperation = delinOp.getOperation();

  // For each result of `delinOp`, the dim or symbol expression it is bound to
  // in `map`. A null entry means the result is not an operand of the map.
  SmallVector<AffineExpr> boundExprs(delinOp.getNumResults(), AffineExpr());
  for (auto [pos, dim] : llvm::enumerate(dims)) {
    auto result = dyn_cast_if_present<OpResult>(dim);
    if (result && result.getOwner() == delinOperation)
      boundExprs[result.getResultNumber()] = getAffineDimExpr(pos, ctx);
  }
  for (auto [pos, sym] : llvm::enumerate(syms)) {
    auto result = dyn_cast_if_present<OpResult>(sym);
    if (result && result.getOwner() == delinOperation)
      boundExprs[result.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
  }
  const bool everyResultBound = llvm::all_of(
      boundExprs, [](AffineExpr bound) { return bound != AffineExpr(); });
  if (!everyResultBound)
    return failure();

  const bool allDims = llvm::all_of(boundExprs, llvm::IsaPred<AffineDimExpr>);

  // Build the summands of the inverse (linearizing) expression, innermost
  // result first. This isn't zip_equal since the basis sometimes omits the
  // size for the first result.
  auto staticBasis = delinOp.getStaticBasis();
  llvm::SmallDenseSet<AffineExpr, 4> inverseTerms;
  int64_t stride = 1;
  for (auto [bound, size] :
       llvm::zip(llvm::reverse(boundExprs), llvm::reverse(staticBasis))) {
    inverseTerms.insert(bound * getAffineConstantExpr(stride, ctx));
    stride *= size;
  }
  // No size was given for the first result: it takes the accumulated stride.
  if (boundExprs.size() != staticBasis.size())
    inverseTerms.insert(boundExprs[0] * stride);

  // The linear index becomes a fresh trailing dim or symbol.
  AffineExpr linearIndexExpr = allDims ? getAffineDimExpr(dims.size(), ctx)
                                       : getAffineSymbolExpr(syms.size(), ctx);

  DenseMap<AffineExpr, AffineExpr> rewrites;
  for (AffineExpr result : map->getResults())
    shortenAddChainsContainingAll(result, inverseTerms, linearIndexExpr,
                                  rewrites);
  if (rewrites.empty())
    return failure();

  (allDims ? dims : syms).push_back(delinOp.getLinearIndex());
  *map = map->replace(rewrites, dims.size(), syms.size());

  // Blank out dimensions and symbols that the rewritten map no longer uses.
  for (AffineExpr bound : boundExprs) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(bound)) {
      if (!map->isFunctionOfDim(dimExpr.getPosition()))
        dims[dimExpr.getPosition()] = nullptr;
    } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound)) {
      if (!map->isFunctionOfSymbol(symExpr.getPosition()))
        syms[symExpr.getPosition()] = nullptr;
    }
  }
  return success();
}

/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
Expand Down Expand Up @@ -1157,6 +1293,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
syms);
}

if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
syms);
}

auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
Expand Down
130 changes: 130 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind

// -----

// An affine.apply that exactly re-linearizes every result of an
// affine.delinearize_index folds to the delinearized value. The expected
// output is six stores that use %arg0 directly as both value and index.
// CHECK-LABEL: func @delin_apply_cancel_exact
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
// CHECK-NOT: memref.store
// CHECK: return
func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) {
  // %a and %c give no basis size for their first result; %b gives one for
  // every result.
  %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
  %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
  %c:2 = affine.delinearize_index %arg0 into (20) : index, index

  // %t1 - %t3: the same inverse of %a with the summands and the symbol
  // bindings permuted.
  %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
  memref.store %t1, %arg1[%t1] : memref<?xindex>

  %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
  memref.store %t2, %arg1[%t2] : memref<?xindex>

  %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
  memref.store %t3, %arg1[%t3] : memref<?xindex>

  // %t4: inverse of %b.
  %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
  memref.store %t4, %arg1[%t4] : memref<?xindex>

  // %t5, %t6: inverse of the two-result %c, in both summand orders.
  %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
  memref.store %t5, %arg1[%t5] : memref<?xindex>

  %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
  memref.store %t6, %arg1[%t6] : memref<?xindex>

  return
}

// -----

// Same cancellation as the symbol-based test, but the delinearize results are
// bound to dimensions (d0, d1, d2) of the map and the linear index is a loop
// induction variable. The expected output stores the induction variable
// directly.
// CHECK-LABEL: func @delin_apply_cancel_exact_dim
// CHECK: affine.for %[[arg1:.+]] = 0 to 256
// CHECK: memref.store %[[arg1]]
// CHECK: return
func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
affine.for %arg1 = 0 to 256 {
%a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
%i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
memref.store %i, %arg0[%i] : memref<?xindex>
}
return
}

// -----

// The inverse expression appears in a sum together with an extra constant.
// The expected map keeps the constant and replaces the rest with the
// delinearized value: (s0 + 512) applied to %arg0.
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
// CHECK-LABEL: func @delin_apply_cancel_const_term
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
// CHECK: return
func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) {
%a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index

%t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
memref.store %t1, %arg1[%t1] : memref<?xindex>

return
}

// -----

// The inverse expression appears in a sum together with an unrelated symbol
// (%arg2) and a constant. The expected map keeps both extra terms and
// replaces the rest with the delinearized value: (s0 + s1 + 512) applied to
// [%arg2, %arg0].
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
// CHECK-LABEL: func @delin_apply_cancel_var_term
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
// CHECK: return
func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) {
%a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index

%t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
memref.store %t1, %arg1[%t1] : memref<?xindex>

return
}

// -----

// The inverse expression occurs twice, each time nested under another
// operation (a ceildiv and a multiplication). The expected map has both
// occurrences replaced by the delinearized value.
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
// CHECK: return
func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) {
%a:2 = affine.delinearize_index %arg0 into (20) : index, index

%t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
memref.store %t1, %arg1[%t1] : memref<?xindex>

return
}

// -----

// The sum contains s0 twice. In the expected output one copy is cancelled
// together with s1 * 20 and the other copy is kept, so the delinearize op
// stays and the apply becomes (s0 + s1) on [%a#1, %arg0].
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
// CHECK: return
func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) {
%a:2 = affine.delinearize_index %arg0 into (20) : index, index

%t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#1, %a#0]
memref.store %t1, %arg1[%t1] : memref<?xindex>

return
}

// -----

// Negative test: the apply uses only two of the three delinearize results
// (%a#0 is absent), so the expected output leaves the map and its operands
// unchanged.
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
// CHECK-LABEL: func @delin_apply_dont_cancel_partial
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
// CHECK: return
func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) {
%a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index

%t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
memref.store %t1, %arg1[%t1] : memref<?xindex>

return
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
Expand Down