-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][affine] Add static basis support to affine.delinearize #113846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| #include "mlir/Dialect/Affine/IR/AffineValueMap.h" | ||
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/UB/IR/UBOps.h" | ||
| #include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
| #include "mlir/IR/AffineExprVisitor.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
| #include "mlir/IR/IntegerSet.h" | ||
|
|
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes( | |
| RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { | ||
| AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties, | ||
| regions); | ||
| inferredReturnTypes.assign(adaptor.getBasis().size(), | ||
| inferredReturnTypes.assign(adaptor.getStaticBasis().size(), | ||
| IndexType::get(context)); | ||
| return success(); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result, | ||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, ValueRange basis) { | ||
| SmallVector<Value> dynamicBasis; | ||
| SmallVector<int64_t> staticBasis; | ||
| dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, | ||
| staticBasis); | ||
| build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, | ||
| ArrayRef<OpFoldResult> basis) { | ||
| result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType())); | ||
| result.addOperands(linearIndex); | ||
| SmallVector<Value> basisValues = | ||
| llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value { | ||
| std::optional<int64_t> staticDim = getConstantIntValue(ofr); | ||
| if (staticDim.has_value()) | ||
| return builder.create<arith::ConstantIndexOp>(result.location, | ||
| *staticDim); | ||
| return llvm::dyn_cast_if_present<Value>(ofr); | ||
| }); | ||
| result.addOperands(basisValues); | ||
| SmallVector<Value> dynamicBasis; | ||
| SmallVector<int64_t> staticBasis; | ||
| dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); | ||
| build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, | ||
| ArrayRef<int64_t> basis) { | ||
| build(odsBuilder, odsState, linearIndex, ValueRange{}, basis); | ||
| } | ||
|
|
||
| LogicalResult AffineDelinearizeIndexOp::verify() { | ||
| if (getBasis().empty()) | ||
| if (getStaticBasis().empty()) | ||
| return emitOpError("basis should not be empty"); | ||
| if (getNumResults() != getBasis().size()) | ||
| if (getNumResults() != getStaticBasis().size()) | ||
| return emitOpError("should return an index for each basis element"); | ||
| auto dynamicMarkersCount = | ||
| llvm::count_if(getStaticBasis(), ShapedType::isDynamic); | ||
| if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size()) | ||
| return emitOpError( | ||
| "mismatch between dynamic and static basis (kDynamic marker but no " | ||
| "corresponding dynamic basis entry) -- this can only happen due to an " | ||
| "incorrect fold/rewrite"); | ||
|
Comment on lines
+4554
to
+4555
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we not construct such an op using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The builder construction is along the lines of |
||
| return success(); | ||
| } | ||
|
|
||
|
|
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis | |
|
|
||
| // Replace all indices corresponding to unit-extent basis with 0. | ||
| // Remaining basis can be used to get a new `affine.delinearize_index` op. | ||
| SmallVector<Value> newOperands; | ||
| for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) { | ||
| if (matchPattern(basis, m_One())) | ||
| SmallVector<OpFoldResult> newOperands; | ||
| for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) { | ||
| std::optional<int64_t> basisVal = getConstantIntValue(basis); | ||
| if (basisVal && *basisVal == 1) | ||
| replacements[index] = getZero(); | ||
| else | ||
| newOperands.push_back(basis); | ||
| } | ||
|
|
||
| if (newOperands.size() == delinearizeOp.getBasis().size()) | ||
| if (newOperands.size() == delinearizeOp.getStaticBasis().size()) | ||
| return failure(); | ||
|
|
||
| if (!newOperands.empty()) { | ||
|
|
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop | |
|
|
||
| LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, | ||
| PatternRewriter &rewriter) const override { | ||
| auto basis = delinearizeOp.getBasis(); | ||
| if (basis.size() != 1) | ||
| if (delinearizeOp.getStaticBasis().size() != 1) | ||
| return failure(); | ||
| auto basis = delinearizeOp.getMixedBasis(); | ||
|
|
||
| // Check that the `linear_index` is an induction variable. | ||
| auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex()); | ||
|
|
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop | |
| // Check that the upper-bound is the basis. | ||
| auto upperBounds = loopLikeOp.getLoopUpperBounds(); | ||
| if (!upperBounds || upperBounds->size() != 1 || | ||
| upperBounds->front() != getAsOpFoldResult(basis.front())) { | ||
| upperBounds->front() != basis.front()) { | ||
| return rewriter.notifyMatchFailure(delinearizeOp, | ||
| "`basis` is not upper bound"); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.