-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] support dynamic indexing in VectorEmulateNarrowTypes
#114169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
2effa6a
2580b46
0b9bdce
a9d7260
b777a60
fc29242
2b86a23
8225f72
ce75839
1c6f3c0
74840ce
80c313b
db1c38e
71583fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" | ||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/IR/OpDefinition.h" | ||
| #include "mlir/IR/TypeUtilities.h" | ||
| #include "mlir/IR/Value.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
|
|
@@ -37,16 +38,17 @@ using namespace mlir; | |
|
|
||
| /// Returns a compressed mask. The mask value is set only if any mask is present | ||
| /// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset` | ||
| /// equals to 2, the following mask: | ||
| /// equals to 1 (intraDataOffset strictly smaller than scale), the following | ||
| /// mask: | ||
| /// | ||
| /// %mask = [1, 1, 1, 0, 0, 0] | ||
| /// %mask = [1, 1, 0, 0, 0, 0] | ||
| /// | ||
| /// will first be padded with number of `intraDataOffset` zeros: | ||
| /// %mask = [0, 0, 1, 1, 1, 0, 0, 0] | ||
| /// %mask = [0, 1, 1, 0, 0, 0, 0, 0] | ||
| /// | ||
| /// then it will return the following new compressed mask: | ||
| /// | ||
| /// %mask = [0, 1, 1, 0] | ||
| /// %mask = [1, 1, 0, 0] | ||
| static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, | ||
| Location loc, Value mask, | ||
| int origElements, int scale, | ||
|
|
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, | |
| shape.back() = numElements; | ||
| auto newMaskType = VectorType::get(shape, rewriter.getI1Type()); | ||
| if (createMaskOp) { | ||
| // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp. | ||
| if (intraDataOffset != 0) | ||
| return failure(); | ||
lialan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| OperandRange maskOperands = createMaskOp.getOperands(); | ||
| size_t numMaskOperands = maskOperands.size(); | ||
| AffineExpr s0; | ||
|
|
@@ -129,9 +128,17 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, | |
| return newMask; | ||
| } | ||
|
|
||
| /// A wrapper function for emitting `vector.extract_strided_slice`. The vector | ||
| /// has to be of 1-D shape. | ||
|
||
| static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc, | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| VectorType extractType, Value vector, | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| int64_t frontOffset, int64_t subvecSize) { | ||
| auto vectorType = dyn_cast<VectorType>(vector.getType()); | ||
| assert(vectorType && "expected vector type"); | ||
| assert(vectorType.getShape().size() == 1 && "expected 1-D vector type"); | ||
| assert(extractType.getShape().size() == 1 && | ||
| "extractType must be 1-D vector type"); | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| auto offsets = rewriter.getI64ArrayAttr({frontOffset}); | ||
| auto sizes = rewriter.getI64ArrayAttr({subvecSize}); | ||
| auto strides = rewriter.getI64ArrayAttr({1}); | ||
|
|
@@ -141,14 +148,61 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc, | |
| ->getResult(0); | ||
| } | ||
|
|
||
| /// A wrapper function for emitting `vector.insert_strided_slice`. The source | ||
| /// and dest vectors must be of 1-D shape. | ||
| static Value insertSubvectorInto(RewriterBase &rewriter, Location loc, | ||
| Value src, Value dest, int64_t offset) { | ||
| auto srcType = dyn_cast<VectorType>(src.getType()); | ||
| assert(srcType && "expected vector type"); | ||
| assert(srcType.getShape().size() == 1 && "expected 1-D vector type"); | ||
| auto destType = dyn_cast<VectorType>(dest.getType()); | ||
| assert(destType && "expected vector type"); | ||
| assert(destType.getShape().size() == 1 && "expected 1-D vector type"); | ||
|
|
||
| auto offsets = rewriter.getI64ArrayAttr({offset}); | ||
| auto strides = rewriter.getI64ArrayAttr({1}); | ||
| return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src, | ||
| dest, offsets, strides); | ||
| } | ||
|
|
||
| /// Extracts `lengthSubvec` elements from `srcVec` into `destVec` starting at | ||
| /// the offset specified by `srcOffsetVar`. Use this function when | ||
| /// `srcOffsetVar` is not a constant, making it impossible to use | ||
| /// vector.extract_strided_slice, as it requires constant offsets. | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc, | ||
| TypedValue<VectorType> source, | ||
| Value dest, OpFoldResult offset, | ||
| int64_t numElementsToExtract) { | ||
| for (int i = 0; i < numElementsToExtract; ++i) { | ||
| Value extractLoc = | ||
| (i == 0) ? offset.dyn_cast<Value>() | ||
| : rewriter.create<arith::AddIOp>( | ||
| loc, rewriter.getIndexType(), offset.dyn_cast<Value>(), | ||
| rewriter.create<arith::ConstantIndexOp>(loc, i)); | ||
| auto extractOp = | ||
| rewriter.create<vector::ExtractOp>(loc, source, extractLoc); | ||
| dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i); | ||
| } | ||
| return dest; | ||
| } | ||
|
|
||
| /// Load `numLoadedElements` of `newElementType` from `base` at | ||
| /// `linearizedIndices`, then bitcast the result into a vector of | ||
| /// `oldElementType`. | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| static TypedValue<VectorType> | ||
lialan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc, | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Value base, OpFoldResult linearizedIndices, | ||
| int64_t numElementsToLoad, Type oldElememtType, | ||
| Type newElementType) { | ||
| auto scale = newElementType.getIntOrFloatBitWidth() / | ||
| oldElememtType.getIntOrFloatBitWidth(); | ||
| auto newLoad = rewriter.create<vector::LoadOp>( | ||
| loc, VectorType::get(numElementsToLoad, newElementType), base, | ||
| getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); | ||
| return rewriter.create<vector::BitCastOp>( | ||
| loc, VectorType::get(numElementsToLoad * scale, oldElememtType), newLoad); | ||
| }; | ||
|
|
||
| namespace { | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
|
|
@@ -380,25 +434,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { | |
| ? getConstantIntValue(linearizedInfo.intraDataOffset) | ||
| : 0; | ||
|
|
||
| if (!foldedIntraVectorOffset) { | ||
| // unimplemented case for dynamic intra vector offset | ||
| return failure(); | ||
| } | ||
|
|
||
| // always load enough elements which can cover the original elements | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| auto maxintraDataOffset = | ||
lialan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1; | ||
| auto numElements = | ||
| llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale); | ||
| auto newLoad = rewriter.create<vector::LoadOp>( | ||
| loc, VectorType::get(numElements, newElementType), adaptor.getBase(), | ||
| getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); | ||
|
|
||
| Value result = rewriter.create<vector::BitCastOp>( | ||
| loc, VectorType::get(numElements * scale, oldElementType), newLoad); | ||
| llvm::divideCeil(maxintraDataOffset + origElements, scale); | ||
| Value result = | ||
| emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices, | ||
| numElements, oldElementType, newElementType); | ||
|
|
||
| if (isUnalignedEmulation) { | ||
| result = extractSubvectorFrom(rewriter, loc, op.getType(), result, | ||
| *foldedIntraVectorOffset, origElements); | ||
| if (foldedIntraVectorOffset) { | ||
| if (isUnalignedEmulation) { | ||
| result = extractSubvectorFrom(rewriter, loc, op.getType(), result, | ||
| *foldedIntraVectorOffset, origElements); | ||
| } | ||
| } else { | ||
| auto resultVector = rewriter.create<arith::ConstantOp>( | ||
| loc, op.getType(), rewriter.getZeroAttr(op.getType())); | ||
| result = dynamicallyExtractSubVector( | ||
| rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector, | ||
| linearizedInfo.intraDataOffset, origElements); | ||
| } | ||
|
|
||
| rewriter.replaceOp(op, result); | ||
| return success(); | ||
| } | ||
|
|
@@ -604,13 +660,10 @@ struct ConvertVectorTransferRead final | |
| ? getConstantIntValue(linearizedInfo.intraDataOffset) | ||
| : 0; | ||
|
|
||
| if (!foldedIntraVectorOffset) { | ||
| // unimplemented case for dynamic inra-vector offset | ||
| return failure(); | ||
| } | ||
|
|
||
| auto maxIntraVectorOffset = | ||
| foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1; | ||
| auto numElements = | ||
| llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale); | ||
| llvm::divideCeil(maxIntraVectorOffset + origElements, scale); | ||
|
|
||
| auto newRead = rewriter.create<vector::TransferReadOp>( | ||
| loc, VectorType::get(numElements, newElementType), adaptor.getSource(), | ||
|
|
@@ -621,9 +674,17 @@ struct ConvertVectorTransferRead final | |
| loc, VectorType::get(numElements * scale, oldElementType), newRead); | ||
|
|
||
| Value result = bitCast->getResult(0); | ||
| if (isUnalignedEmulation) { | ||
| result = extractSubvectorFrom(rewriter, loc, op.getType(), result, | ||
| *foldedIntraVectorOffset, origElements); | ||
| if (foldedIntraVectorOffset) { | ||
| if (isUnalignedEmulation) { | ||
| result = extractSubvectorFrom(rewriter, loc, op.getType(), result, | ||
| *foldedIntraVectorOffset, origElements); | ||
| } | ||
| } else { | ||
| auto zeros = rewriter.create<arith::ConstantOp>( | ||
| loc, op.getType(), rewriter.getZeroAttr(op.getType())); | ||
| result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, | ||
| linearizedInfo.intraDataOffset, | ||
| origElements); | ||
| } | ||
| rewriter.replaceOp(op, result); | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.