Skip to content

Commit ce112a7

Browse files
lialan and hanhanW authored
[MLIR] support dynamic indexing in VectorEmulateNarrowTypes (#114169)
* Supports `vector.load` and `vector.transfer_read` ops. * In the case of dynamic indexing, use per-element insertion/extraction to build desired narrow type vectors. * Fixed wrong function comment of `getCompressedMaskOp`. --------- Co-authored-by: Han-Chung Wang <[email protected]>
1 parent 9540a7a commit ce112a7

File tree

2 files changed

+254
-75
lines changed

2 files changed

+254
-75
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 106 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/OpDefinition.h"
2122
#include "mlir/IR/TypeUtilities.h"
2223
#include "mlir/IR/Value.h"
2324
#include "mlir/Transforms/DialectConversion.h"
@@ -37,16 +38,17 @@ using namespace mlir;
3738

3839
/// Returns a compressed mask. The mask value is set only if any mask is present
3940
/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
40-
/// equals to 2, the following mask:
41+
/// equals to 1 (intraDataOffset strictly smaller than scale), the following
42+
/// mask:
4143
///
42-
/// %mask = [1, 1, 1, 0, 0, 0]
44+
/// %mask = [1, 1, 0, 0, 0, 0]
4345
///
4446
/// will first be padded with number of `intraDataOffset` zeros:
45-
/// %mask = [0, 0, 1, 1, 1, 0, 0, 0]
47+
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
4648
///
4749
/// then it will return the following new compressed mask:
4850
///
49-
/// %mask = [0, 1, 1, 0]
51+
/// %mask = [1, 1, 0, 0]
5052
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
5153
Location loc, Value mask,
5254
int origElements, int scale,
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7577
shape.back() = numElements;
7678
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
7779
if (createMaskOp) {
78-
// TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
79-
if (intraDataOffset != 0)
80-
return failure();
8180
OperandRange maskOperands = createMaskOp.getOperands();
8281
size_t numMaskOperands = maskOperands.size();
8382
AffineExpr s0;
@@ -129,26 +128,79 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
129128
return newMask;
130129
}
131130

132-
static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
133-
VectorType extractType, Value vector,
134-
int64_t frontOffset, int64_t subvecSize) {
131+
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
132+
/// emitting `vector.extract_strided_slice`.
133+
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
134+
VectorType extractType, Value source,
135+
int64_t frontOffset,
136+
int64_t subvecSize) {
137+
auto vectorType = cast<VectorType>(source.getType());
138+
assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
139+
"expected 1-D source and destination types");
135140
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
136141
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
137142
auto strides = rewriter.getI64ArrayAttr({1});
138143
return rewriter
139-
.create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
144+
.create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
140145
sizes, strides)
141146
->getResult(0);
142147
}
143148

144-
static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
145-
Value src, Value dest, int64_t offset) {
149+
/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
150+
/// at `offset`. it is a wrapper function for emitting
151+
/// `vector.insert_strided_slice`.
152+
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
153+
Value src, Value dest, int64_t offset) {
154+
auto srcType = cast<VectorType>(src.getType());
155+
auto destType = cast<VectorType>(dest.getType());
156+
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
157+
"expected source and dest to be vector type");
146158
auto offsets = rewriter.getI64ArrayAttr({offset});
147159
auto strides = rewriter.getI64ArrayAttr({1});
148160
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
149161
dest, offsets, strides);
150162
}
151163

164+
/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
165+
/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
166+
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
167+
/// use it when `offset` cannot be folded into a constant value.
168+
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
169+
TypedValue<VectorType> source,
170+
Value dest, OpFoldResult offset,
171+
int64_t numElementsToExtract) {
172+
for (int i = 0; i < numElementsToExtract; ++i) {
173+
Value extractLoc =
174+
(i == 0) ? offset.dyn_cast<Value>()
175+
: rewriter.create<arith::AddIOp>(
176+
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
177+
rewriter.create<arith::ConstantIndexOp>(loc, i));
178+
auto extractOp =
179+
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
180+
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
181+
}
182+
return dest;
183+
}
184+
185+
/// Returns the op sequence for an emulated sub-byte data type vector load.
186+
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187+
/// The load location is given by `base` and `linearizedIndices`, and the
188+
/// load size is given by `numEmulatedElementsToLoad`.
189+
static TypedValue<VectorType>
190+
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
191+
OpFoldResult linearizedIndices,
192+
int64_t numEmultedElementsToLoad, Type origElemType,
193+
Type emulatedElemType) {
194+
auto scale = emulatedElemType.getIntOrFloatBitWidth() /
195+
origElemType.getIntOrFloatBitWidth();
196+
auto newLoad = rewriter.create<vector::LoadOp>(
197+
loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
198+
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
199+
return rewriter.create<vector::BitCastOp>(
200+
loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
201+
newLoad);
202+
};
203+
152204
namespace {
153205

154206
//===----------------------------------------------------------------------===//
@@ -380,25 +432,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380432
? getConstantIntValue(linearizedInfo.intraDataOffset)
381433
: 0;
382434

383-
if (!foldedIntraVectorOffset) {
384-
// unimplemented case for dynamic intra vector offset
385-
return failure();
386-
}
387-
435+
// Always load enough elements which can cover the original elements.
436+
int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
388437
auto numElements =
389-
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
390-
auto newLoad = rewriter.create<vector::LoadOp>(
391-
loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
392-
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
393-
394-
Value result = rewriter.create<vector::BitCastOp>(
395-
loc, VectorType::get(numElements * scale, oldElementType), newLoad);
396-
397-
if (isUnalignedEmulation) {
398-
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
399-
*foldedIntraVectorOffset, origElements);
438+
llvm::divideCeil(maxintraDataOffset + origElements, scale);
439+
Value result =
440+
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
441+
numElements, oldElementType, newElementType);
442+
443+
if (foldedIntraVectorOffset) {
444+
if (isUnalignedEmulation) {
445+
result =
446+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
447+
*foldedIntraVectorOffset, origElements);
448+
}
449+
} else {
450+
auto resultVector = rewriter.create<arith::ConstantOp>(
451+
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
452+
result = dynamicallyExtractSubVector(
453+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
454+
linearizedInfo.intraDataOffset, origElements);
400455
}
401-
402456
rewriter.replaceOp(op, result);
403457
return success();
404458
}
@@ -513,8 +567,8 @@ struct ConvertVectorMaskedLoad final
513567
// create an empty vector of the new type
514568
auto emptyVector = rewriter.create<arith::ConstantOp>(
515569
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
516-
passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
517-
*foldedIntraVectorOffset);
570+
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
571+
*foldedIntraVectorOffset);
518572
}
519573
auto newPassThru =
520574
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -537,16 +591,17 @@ struct ConvertVectorMaskedLoad final
537591
// TODO: can fold if op's mask is constant
538592
auto emptyVector = rewriter.create<arith::ConstantOp>(
539593
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
540-
mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
541-
*foldedIntraVectorOffset);
594+
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
595+
*foldedIntraVectorOffset);
542596
}
543597

544598
Value result =
545599
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
546600

547601
if (isUnalignedEmulation) {
548-
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
549-
*foldedIntraVectorOffset, origElements);
602+
result =
603+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
604+
*foldedIntraVectorOffset, origElements);
550605
}
551606
rewriter.replaceOp(op, result);
552607

@@ -604,13 +659,10 @@ struct ConvertVectorTransferRead final
604659
? getConstantIntValue(linearizedInfo.intraDataOffset)
605660
: 0;
606661

607-
if (!foldedIntraVectorOffset) {
608-
// unimplemented case for dynamic inra-vector offset
609-
return failure();
610-
}
611-
662+
auto maxIntraVectorOffset =
663+
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
612664
auto numElements =
613-
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
665+
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
614666

615667
auto newRead = rewriter.create<vector::TransferReadOp>(
616668
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +673,18 @@ struct ConvertVectorTransferRead final
621673
loc, VectorType::get(numElements * scale, oldElementType), newRead);
622674

623675
Value result = bitCast->getResult(0);
624-
if (isUnalignedEmulation) {
625-
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
626-
*foldedIntraVectorOffset, origElements);
676+
if (foldedIntraVectorOffset) {
677+
if (isUnalignedEmulation) {
678+
result =
679+
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
680+
*foldedIntraVectorOffset, origElements);
681+
}
682+
} else {
683+
auto zeros = rewriter.create<arith::ConstantOp>(
684+
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
685+
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
686+
linearizedInfo.intraDataOffset,
687+
origElements);
627688
}
628689
rewriter.replaceOp(op, result);
629690

0 commit comments

Comments
 (0)