Skip to content

Commit 9b1890c

Browse files
committed
[mlir][Affine] Split off delinearize parts that depend on last component
If we have %0 = affine.linearize_index disjoint [%a, %b] by (A, B) %1:3 = affine.delinearize_index %0 into (A, B1, B2) where B = B1 * B2 (or some mor complex product), we can simplify this to %0 = affine.linearize_index disjoint [%a] by (A) %1a:1 = affine.delinearize_index %0 into (A) %1b:2 = affine.delinearize_index %b into (B1, B2) This, and more complex cases, prevent us from adding terms together only to divide them away from each other.
1 parent 565a9ac commit 9b1890c

File tree

2 files changed

+151
-2
lines changed

2 files changed

+151
-2
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
46944694
return success();
46954695
}
46964696
};
4697+
4698+
/// If the input to a delinearization is a disjoint linearization, and the
4699+
/// last k > 1 components of the delinearization basis multiply to the
4700+
/// last component of the linearization basis, break the linearization and
4701+
/// delinearization into two parts, peeling off the last input to linearization.
4702+
///
4703+
/// For example:
4704+
/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4705+
/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4706+
/// becomes
4707+
/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4708+
/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4709+
/// %2:2 = affine.delinearize_index %x by (8, 4) : index
4710+
/// where the original %1:4 is replaced by %1:2 ++ %2:2
4711+
struct SplitDelinearizeSpanningLastLinearizeArg final
4712+
: OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4713+
using OpRewritePattern::OpRewritePattern;
4714+
4715+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4716+
PatternRewriter &rewriter) const override {
4717+
auto linearizeOp = delinearizeOp.getLinearIndex()
4718+
.getDefiningOp<affine::AffineLinearizeIndexOp>();
4719+
if (!linearizeOp)
4720+
return rewriter.notifyMatchFailure(delinearizeOp,
4721+
"index doesn't come from linearize");
4722+
4723+
if (!linearizeOp.getDisjoint())
4724+
return rewriter.notifyMatchFailure(linearizeOp,
4725+
"linearize isn't disjoint");
4726+
4727+
int64_t target = linearizeOp.getStaticBasis().back();
4728+
if (ShapedType::isDynamic(target))
4729+
return rewriter.notifyMatchFailure(
4730+
linearizeOp, "linearize ends with dynamic basis value");
4731+
4732+
int64_t sizeToSplit = 1;
4733+
size_t elemsToSplit = 0;
4734+
ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4735+
for (int64_t basisElem : llvm::reverse(basis)) {
4736+
if (ShapedType::isDynamic(basisElem))
4737+
return rewriter.notifyMatchFailure(
4738+
delinearizeOp, "dynamic basis element while scanning for split");
4739+
sizeToSplit *= basisElem;
4740+
elemsToSplit += 1;
4741+
4742+
if (sizeToSplit > target)
4743+
return rewriter.notifyMatchFailure(delinearizeOp,
4744+
"overshot last argument size");
4745+
if (sizeToSplit == target)
4746+
break;
4747+
}
4748+
4749+
if (sizeToSplit < target)
4750+
return rewriter.notifyMatchFailure(
4751+
delinearizeOp, "product of known basis elements doesn't exceed last "
4752+
"linearize argument");
4753+
4754+
if (elemsToSplit < 2)
4755+
return rewriter.notifyMatchFailure(
4756+
delinearizeOp, "don't have a non-trivial basis product");
4757+
4758+
Value linearizeWithoutBack =
4759+
rewriter.create<affine::AffineLinearizeIndexOp>(
4760+
linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4761+
linearizeOp.getDynamicBasis(),
4762+
linearizeOp.getStaticBasis().drop_back(),
4763+
linearizeOp.getDisjoint());
4764+
auto delinearizeWithoutSplitPart =
4765+
rewriter.create<affine::AffineDelinearizeIndexOp>(
4766+
delinearizeOp.getLoc(), linearizeWithoutBack,
4767+
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4768+
delinearizeOp.hasOuterBound());
4769+
auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4770+
delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4771+
basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4772+
SmallVector<Value> results = llvm::to_vector(
4773+
llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4774+
delinearizeBack.getResults()));
4775+
rewriter.replaceOp(delinearizeOp, results);
4776+
4777+
return success();
4778+
}
4779+
};
46974780
} // namespace
46984781

46994782
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
47004783
RewritePatternSet &patterns, MLIRContext *context) {
47014784
patterns
4702-
.insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
4703-
context);
4785+
.insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
4786+
SplitDelinearizeSpanningLastLinearizeArg>(context);
47044787
}
47054788

47064789
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
17771777

17781778
// -----
17791779

1780+
// CHECK-LABEL: func @split_delinearize_spanning_final_part
1781+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1782+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1783+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1784+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
1785+
// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
1786+
// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
1787+
// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
1788+
func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1789+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1790+
%1:4 = affine.delinearize_index %0 into (2, 8, 8)
1791+
: index, index, index, index
1792+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1793+
}
1794+
1795+
// -----
1796+
1797+
// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
1798+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1799+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1800+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1801+
// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
1802+
// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
1803+
func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1804+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1805+
%1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
1806+
: index, index, index, index
1807+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1808+
}
1809+
1810+
// -----
1811+
1812+
// The delinearize basis doesn't match the last basis element before
1813+
// overshooting it, don't simplify.
1814+
// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
1815+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1816+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1817+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1818+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
1819+
// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
1820+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
1821+
func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1822+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1823+
%1:4 = affine.delinearize_index %0 into (2, 16, 8)
1824+
: index, index, index, index
1825+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1826+
}
1827+
1828+
// -----
1829+
1830+
// The delinearize basis doesn't fully multiply to the final basis element.
1831+
// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
1832+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1833+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
1834+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
1835+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
1836+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1
1837+
func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
1838+
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
1839+
%1:3 = affine.delinearize_index %0 into (4, 8)
1840+
: index, index, index
1841+
return %1#0, %1#1, %1#2 : index, index, index
1842+
}
1843+
1844+
// -----
1845+
17801846
// CHECK-LABEL: @linearize_unit_basis_disjoint
17811847
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
17821848
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index

0 commit comments

Comments
 (0)