-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][affine] Add static basis support to affine.delinearize #113846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| #include "mlir/Dialect/Affine/IR/AffineValueMap.h" | ||
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/UB/IR/UBOps.h" | ||
| #include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
| #include "mlir/IR/AffineExprVisitor.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
| #include "mlir/IR/IntegerSet.h" | ||
|
|
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes( | |
| RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { | ||
| AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties, | ||
| regions); | ||
| inferredReturnTypes.assign(adaptor.getBasis().size(), | ||
| inferredReturnTypes.assign(adaptor.getStaticBasis().size(), | ||
| IndexType::get(context)); | ||
| return success(); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result, | ||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, ValueRange basis) { | ||
| SmallVector<Value> dynamicBasis; | ||
| SmallVector<int64_t> staticBasis; | ||
| dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, | ||
| staticBasis); | ||
| build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, | ||
| ArrayRef<OpFoldResult> basis) { | ||
| result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType())); | ||
| result.addOperands(linearIndex); | ||
| SmallVector<Value> basisValues = | ||
| llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value { | ||
| std::optional<int64_t> staticDim = getConstantIntValue(ofr); | ||
| if (staticDim.has_value()) | ||
| return builder.create<arith::ConstantIndexOp>(result.location, | ||
| *staticDim); | ||
| return llvm::dyn_cast_if_present<Value>(ofr); | ||
| }); | ||
| result.addOperands(basisValues); | ||
| SmallVector<Value> dynamicBasis; | ||
| SmallVector<int64_t> staticBasis; | ||
| dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); | ||
| build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis); | ||
| } | ||
|
|
||
| void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder, | ||
| OperationState &odsState, | ||
| Value linearIndex, | ||
| ArrayRef<int64_t> basis) { | ||
| build(odsBuilder, odsState, linearIndex, ValueRange{}, basis); | ||
| } | ||
|
|
||
| LogicalResult AffineDelinearizeIndexOp::verify() { | ||
| if (getBasis().empty()) | ||
| if (getStaticBasis().empty()) | ||
| return emitOpError("basis should not be empty"); | ||
| if (getNumResults() != getBasis().size()) | ||
| if (getNumResults() != getStaticBasis().size()) | ||
| return emitOpError("should return an index for each basis element"); | ||
| auto dynamicMarkersCount = | ||
| llvm::count_if(getStaticBasis(), ShapedType::isDynamic); | ||
| if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size()) | ||
| return emitOpError( | ||
| "mismatch between dynamic and static basis (kDynamic marker but no " | ||
| "corresponding dynamic basis entry) -- this can only happen due to an " | ||
| "incorrect fold/rewrite"); | ||
|
Comment on lines
+4554
to
+4555
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we not construct such an op using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The builder construction is along the lines of |
||
| return success(); | ||
| } | ||
|
|
||
|
|
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis | |
|
|
||
| // Replace all indices corresponding to unit-extent basis with 0. | ||
| // Remaining basis can be used to get a new `affine.delinearize_index` op. | ||
| SmallVector<Value> newOperands; | ||
| for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) { | ||
| if (matchPattern(basis, m_One())) | ||
| SmallVector<OpFoldResult> newOperands; | ||
| for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) { | ||
| std::optional<int64_t> basisVal = getConstantIntValue(basis); | ||
| if (basisVal && *basisVal == 1) | ||
| replacements[index] = getZero(); | ||
| else | ||
| newOperands.push_back(basis); | ||
| } | ||
|
|
||
| if (newOperands.size() == delinearizeOp.getBasis().size()) | ||
| if (newOperands.size() == delinearizeOp.getStaticBasis().size()) | ||
| return failure(); | ||
|
|
||
| if (!newOperands.empty()) { | ||
|
|
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop | |
|
|
||
| LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, | ||
| PatternRewriter &rewriter) const override { | ||
| auto basis = delinearizeOp.getBasis(); | ||
| if (basis.size() != 1) | ||
| if (delinearizeOp.getStaticBasis().size() != 1) | ||
| return failure(); | ||
| auto basis = delinearizeOp.getMixedBasis(); | ||
|
|
||
| // Check that the `linear_index` is an induction variable. | ||
| auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex()); | ||
|
|
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop | |
| // Check that the upper-bound is the basis. | ||
| auto upperBounds = loopLikeOp.getLoopUpperBounds(); | ||
| if (!upperBounds || upperBounds->size() != 1 || | ||
| upperBounds->front() != getAsOpFoldResult(basis.front())) { | ||
| upperBounds->front() != basis.front()) { | ||
| return rewriter.notifyMatchFailure(delinearizeOp, | ||
| "`basis` is not upper bound"); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc, | |
| return result; | ||
| } | ||
|
|
||
| static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc, | ||
|
||
| ArrayRef<OpFoldResult> set) { | ||
| if (set.empty()) | ||
| return failure(); | ||
| OpFoldResult result = set[0]; | ||
| AffineExpr s0, s1; | ||
| bindSymbols(b.getContext(), s0, s1); | ||
| for (unsigned i = 1, e = set.size(); i < e; i++) | ||
| result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]}); | ||
| return result; | ||
| } | ||
|
|
||
| FailureOr<SmallVector<Value>> | ||
| mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, | ||
| ArrayRef<Value> basis) { | ||
|
|
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, | |
| return results; | ||
| } | ||
|
|
||
| FailureOr<SmallVector<Value>> | ||
| mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, | ||
| ArrayRef<OpFoldResult> basis) { | ||
| unsigned numDims = basis.size(); | ||
|
|
||
| SmallVector<Value> divisors; | ||
| for (unsigned i = 1; i < numDims; i++) { | ||
| ArrayRef<OpFoldResult> slice = basis.drop_front(i); | ||
| FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice); | ||
| if (failed(prod)) | ||
| return failure(); | ||
| divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be more efficient to do a scan and collect intermediate products than rerun the product every time.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done (and applied to the existing function), thanks! |
||
|
|
||
| SmallVector<Value> results; | ||
| results.reserve(divisors.size() + 1); | ||
| Value residual = linearIndex; | ||
| for (Value divisor : divisors) { | ||
| DivModValue divMod = getDivMod(b, loc, residual, divisor); | ||
| results.push_back(divMod.quotient); | ||
| residual = divMod.remainder; | ||
| } | ||
| results.push_back(residual); | ||
| return results; | ||
| } | ||
|
|
||
| OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex, | ||
| ArrayRef<OpFoldResult> basis, | ||
| ImplicitLocOpBuilder &builder) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.