@@ -4684,6 +4684,115 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
46844684 patterns.insert <DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
46854685}
46864686
4687+ // ===----------------------------------------------------------------------===//
4688+ // LinearizeIndexOp
4689+ // ===----------------------------------------------------------------------===//
4690+
4691+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4692+ OperationState &odsState,
4693+ ValueRange multiIndex, ValueRange basis,
4694+ bool disjoint) {
4695+ SmallVector<Value> dynamicBasis;
4696+ SmallVector<int64_t > staticBasis;
4697+ dispatchIndexOpFoldResults (getAsOpFoldResult (basis), dynamicBasis,
4698+ staticBasis);
4699+ build (odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4700+ }
4701+
4702+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4703+ OperationState &odsState,
4704+ ValueRange multiIndex,
4705+ ArrayRef<OpFoldResult> basis,
4706+ bool disjoint) {
4707+ SmallVector<Value> dynamicBasis;
4708+ SmallVector<int64_t > staticBasis;
4709+ dispatchIndexOpFoldResults (basis, dynamicBasis, staticBasis);
4710+ build (odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4711+ }
4712+
4713+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4714+ OperationState &odsState,
4715+ ValueRange multiIndex,
4716+ ArrayRef<int64_t > basis, bool disjoint) {
4717+ build (odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4718+ }
4719+
4720+ LogicalResult AffineLinearizeIndexOp::verify () {
4721+ if (getStaticBasis ().empty ())
4722+ return emitOpError (" basis should not be empty" );
4723+
4724+ if (getMultiIndex ().size () != getStaticBasis ().size ())
4725+ return emitOpError (" should be passed an index for each basis element" );
4726+
4727+ auto dynamicMarkersCount =
4728+ llvm::count_if (getStaticBasis (), ShapedType::isDynamic);
4729+ if (static_cast <size_t >(dynamicMarkersCount) != getDynamicBasis ().size ())
4730+ return emitOpError (
4731+ " mismatch between dynamic and static basis (kDynamic marker but no "
4732+ " corresponding dynamic basis entry) -- this can only happen due to an "
4733+ " incorrect fold/rewrite" );
4734+
4735+ return success ();
4736+ }
4737+
4738+ namespace {
4739+ // / Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
4740+ // / %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
4741+ // / %...d)`.
4742+
4743+ // / Note that `disjoint` is required here, because, without it, we could have
4744+ // / `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
4745+ // / is a valid operation where the `%c64` cannot be trivially dropped.
4746+ // /
4747+ // / Alternatively, if `%x` in the above is a known constant 0, remove it even if
4748+ // / the operation isn't asserted to be `disjoint`.
4749+ struct DropLinearizeUnitComponentsIfDisjointOrZero final
4750+ : OpRewritePattern<affine::AffineLinearizeIndexOp> {
4751+ using OpRewritePattern::OpRewritePattern;
4752+
4753+ LogicalResult matchAndRewrite (affine::AffineLinearizeIndexOp op,
4754+ PatternRewriter &rewriter) const override {
4755+ size_t numIndices = op.getMultiIndex ().size ();
4756+ SmallVector<Value> newIndices;
4757+ newIndices.reserve (numIndices);
4758+ SmallVector<OpFoldResult> newBasis;
4759+ newBasis.reserve (numIndices);
4760+
4761+ SmallVector<OpFoldResult> basis = op.getMixedBasis ();
4762+ for (auto [index, basisElem] : llvm::zip_equal (op.getMultiIndex (), basis)) {
4763+ std::optional<int64_t > basisEntry = getConstantIntValue (basisElem);
4764+ if (!basisEntry || *basisEntry != 1 ) {
4765+ newIndices.push_back (index);
4766+ newBasis.push_back (basisElem);
4767+ continue ;
4768+ }
4769+
4770+ std::optional<int64_t > indexValue = getConstantIntValue (index);
4771+ if (!op.getDisjoint () && (!indexValue || *indexValue != 0 )) {
4772+ newIndices.push_back (index);
4773+ newBasis.push_back (basisElem);
4774+ continue ;
4775+ }
4776+ }
4777+ if (newIndices.size () == numIndices)
4778+ return failure ();
4779+
4780+ if (newIndices.size () == 0 ) {
4781+ rewriter.replaceOpWithNewOp <arith::ConstantIndexOp>(op, 0 );
4782+ return success ();
4783+ }
4784+ rewriter.replaceOpWithNewOp <affine::AffineLinearizeIndexOp>(
4785+ op, newIndices, newBasis, op.getDisjoint ());
4786+ return success ();
4787+ }
4788+ };
4789+ } // namespace
4790+
4791+ void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns (
4792+ RewritePatternSet &patterns, MLIRContext *context) {
4793+ patterns.add <DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4794+ }
4795+
46874796// ===----------------------------------------------------------------------===//
46884797// TableGen'd op method definitions
46894798// ===----------------------------------------------------------------------===//
0 commit comments