@@ -4664,6 +4664,112 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
46644664 patterns.insert <DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
46654665}
46664666
4667+ // ===----------------------------------------------------------------------===//
4668+ // LinearizeIndexOp
4669+ // ===----------------------------------------------------------------------===//
4670+
4671+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4672+ OperationState &odsState,
4673+ ValueRange multiIndex, ValueRange basis,
4674+ bool disjoint) {
4675+ SmallVector<Value> dynamicBasis;
4676+ SmallVector<int64_t > staticBasis;
4677+ dispatchIndexOpFoldResults (getAsOpFoldResult (basis), dynamicBasis,
4678+ staticBasis);
4679+ build (odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4680+ }
4681+
4682+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4683+ OperationState &odsState,
4684+ ValueRange multiIndex,
4685+ ArrayRef<OpFoldResult> basis,
4686+ bool disjoint) {
4687+ SmallVector<Value> dynamicBasis;
4688+ SmallVector<int64_t > staticBasis;
4689+ dispatchIndexOpFoldResults (basis, dynamicBasis, staticBasis);
4690+ build (odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4691+ }
4692+
4693+ void AffineLinearizeIndexOp::build (OpBuilder &odsBuilder,
4694+ OperationState &odsState,
4695+ ValueRange multiIndex,
4696+ ArrayRef<int64_t > basis, bool disjoint) {
4697+ build (odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4698+ }
4699+
4700+ LogicalResult AffineLinearizeIndexOp::verify () {
4701+ if (getStaticBasis ().empty ())
4702+ return emitOpError (" basis should not be empty" );
4703+ if (getMultiIndex ().size () != getStaticBasis ().size ())
4704+ return emitOpError (" should be passed an index for each basis element" );
4705+ auto dynamicMarkersCount =
4706+ llvm::count_if (getStaticBasis (), ShapedType::isDynamic);
4707+ if (static_cast <size_t >(dynamicMarkersCount) != getDynamicBasis ().size ())
4708+ return emitOpError (
4709+ " mismatch between dynamic and static basis (kDynamic marker but no "
4710+ " corresponding dynamic basis entry) -- this can only happen due to an "
4711+ " incorrect fold/rewrite" );
4712+ return success ();
4713+ }
4714+
4715+ namespace {
4716+ // / Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
4717+ // / %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
4718+ // / %...d)`.
4719+
4720+ // / Note that `disjoint` is required here, because, without it, we could have
4721+ // / `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
4722+ // / is a valid operation where the `%c64` cannot be trivially dropped.
4723+ // /
4724+ // / Alternatively, if `%x` in the above is a known constant 0, remove it even if
4725+ // / the operation isn't asserted to be `disjoint`.
4726+ struct DropLinearizeUnitComponentsIfDisjointOrZero
4727+ : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
4728+ using OpRewritePattern::OpRewritePattern;
4729+
4730+ LogicalResult matchAndRewrite (affine::AffineLinearizeIndexOp op,
4731+ PatternRewriter &rewriter) const override {
4732+ size_t numIndices = op.getMultiIndex ().size ();
4733+ SmallVector<Value> newIndices;
4734+ newIndices.reserve (numIndices);
4735+ SmallVector<OpFoldResult> newBasis;
4736+ newBasis.reserve (numIndices);
4737+
4738+ SmallVector<OpFoldResult> basis = op.getMixedBasis ();
4739+ for (auto [index, basisElem] : llvm::zip_equal (op.getMultiIndex (), basis)) {
4740+ std::optional<int64_t > basisEntry = getConstantIntValue (basisElem);
4741+ if (!basisEntry || *basisEntry != 1 ) {
4742+ newIndices.push_back (index);
4743+ newBasis.push_back (basisElem);
4744+ continue ;
4745+ }
4746+
4747+ std::optional<int64_t > indexValue = getConstantIntValue (index);
4748+ if (!op.getDisjoint () && (!indexValue || *indexValue != 0 )) {
4749+ newIndices.push_back (index);
4750+ newBasis.push_back (basisElem);
4751+ continue ;
4752+ }
4753+ }
4754+ if (newIndices.size () == numIndices)
4755+ return failure ();
4756+
4757+ if (newIndices.size () == 0 ) {
4758+ rewriter.replaceOpWithNewOp <arith::ConstantIndexOp>(op, 0 );
4759+ return success ();
4760+ }
4761+ rewriter.replaceOpWithNewOp <affine::AffineLinearizeIndexOp>(
4762+ op, newIndices, newBasis, op.getDisjoint ());
4763+ return success ();
4764+ }
4765+ };
4766+ } // namespace
4767+
4768+ void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns (
4769+ RewritePatternSet &patterns, MLIRContext *context) {
4770+ patterns.add <DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4771+ }
4772+
46674773// ===----------------------------------------------------------------------===//
46684774// TableGen'd op method definitions
46694775// ===----------------------------------------------------------------------===//
0 commit comments