Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,40 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Returns a compressed mask. The mask value is set only if any mask is present
/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
/// equals to 1 (intraDataOffset strictly smaller than scale), the following
/// mask:
/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
/// into `vector<2xi1>`.
///
/// The compressed/output mask value is set iff any mask in the corresponding
/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
/// `numSrcElemsPerDest` equals to 2, and `numFrontPadElems` equals to 1, the
/// following mask:
///
/// %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded in the front with number of `intraDataOffset` zeros,
/// and pad zeros in the back to make the number of elements a multiple of
/// `scale` (just to make it easier to compute). The new mask will be:
/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
/// will be added in the back to make the number of elements a multiple of
/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
///
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
///
/// %mask = [1, 1, 0, 0]
///
/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
/// `numSrcElemsPerDest`.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
int numSrcElems,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {

assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");

auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;

Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
Expand Down Expand Up @@ -93,8 +106,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
s0 = s0 + scale - 1;
s0 = s0.floorDiv(scale);
s0 = s0 + numSrcElemsPerDest - 1;
s0 = s0.floorDiv(numSrcElemsPerDest);
OpFoldResult origIndex =
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
OpFoldResult maskIndex =
Expand All @@ -108,18 +121,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t startIndex = intraDataOffset / scale;
int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
int64_t maskIndex =
llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);

// TODO: we only want the mask between [startIndex, maskIndex] to be true,
// the rest are false.
if (intraDataOffset != 0 && maskDimSizes.size() > 1)
if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
return failure();

SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);

if (intraDataOffset == 0) {
if (numFrontPadElems == 0) {
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
} else {
Expand Down
Loading