diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index eb4ce24548e60..e5f2a847994ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,27 +45,40 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-/// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 1 (intraDataOffset strictly smaller than scale), the following
-/// mask:
+/// Returns a compressed mask for the emulated vector. For example, when
+/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
+/// elements span two dest elements), this method compresses `vector<8xi1>`
+/// into `vector<2xi1>`.
+///
+/// The compressed/output mask value is set iff any mask in the corresponding
+/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
+/// `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
+/// following mask:
 ///
 /// %mask = [1, 1, 0, 0, 0, 0]
 ///
-/// will first be padded in the front with number of `intraDataOffset` zeros,
-/// and pad zeros in the back to make the number of elements a multiple of
-/// `scale` (just to make it easier to compute). The new mask will be:
+/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
+/// will be added in the back to make the number of elements a multiple of
+/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
+///
 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
 /// %mask = [1, 1, 0, 0]
+///
+/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
+/// `numSrcElemsPerDest`.
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale,
-                                                  int intraDataOffset = 0) {
-  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
-  auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
+                                                  int numSrcElems,
+                                                  int numSrcElemsPerDest,
+                                                  int numFrontPadElems = 0) {
+
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
+
+  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+                     numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -93,8 +106,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
     bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + scale - 1;
-    s0 = s0.floorDiv(scale);
+    s0 = s0 + numSrcElemsPerDest - 1;
+    s0 = s0.floorDiv(numSrcElemsPerDest);
     OpFoldResult origIndex =
         getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
     OpFoldResult maskIndex =
@@ -108,18 +121,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = intraDataOffset / scale;
-    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+    int64_t maskIndex =
+        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
-    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
       return failure();
 
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
 
-    if (intraDataOffset == 0) {
+    if (numFrontPadElems == 0) {
       newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                         newMaskDimSizes);
     } else {
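
A minimal, self-contained sketch of the compression rule described in the new
doc comment, using the doc comment's own example (numSrcElemsPerDest = 2,
numFrontPadElems = 1). This is plain C++ modelling the semantics on a bool
vector, not the MLIR lowering in the patch (which builds vector-dialect ops);
`compressMask` is an illustrative name, not a helper from the patch.

#include <cassert>
#include <cstdio>
#include <vector>

// OR-reduce each numSrcElemsPerDest-sized group of the front-padded source
// mask into one destination mask bit.
static std::vector<bool> compressMask(const std::vector<bool> &srcMask,
                                      int numSrcElemsPerDest,
                                      int numFrontPadElems = 0) {
  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");
  int numSrcElems = static_cast<int>(srcMask.size());
  // Same ceiling division as in the patch: front padding is added before
  // rounding up to a whole number of destination elements.
  int numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                     numSrcElemsPerDest;
  std::vector<bool> destMask(numDestElems, false);
  for (int i = 0; i < numSrcElems; ++i)
    if (srcMask[i])
      destMask[(numFrontPadElems + i) / numSrcElemsPerDest] = true;
  return destMask;
}

int main() {
  // %mask = [1, 1, 0, 0, 0, 0] is conceptually front-padded to
  // [0, 1, 1, 0, 0, 0, 0, 0] and compressed to [1, 1, 0, 0].
  std::vector<bool> mask = {true, true, false, false, false, false};
  for (bool b : compressMask(mask, /*numSrcElemsPerDest=*/2,
                             /*numFrontPadElems=*/1))
    std::printf("%d ", b ? 1 : 0);
  std::printf("\n"); // prints: 1 1 0 0
}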
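
The two rewritten paths compute the same quantity in different ways: the
vector.create_mask path folds the ceiling division through an affine map using
the identity ceildiv(x, n) == floordiv(x + n - 1, n) (the
`s0 = s0 + numSrcElemsPerDest - 1; s0 = s0.floorDiv(numSrcElemsPerDest)`
lines), while the vector.constant_mask path computes startIndex/maskIndex
directly as integers. Below is a sketch of that integer arithmetic;
`compressConstantRun` is a hypothetical name, not a helper from the patch.

#include <cstdint>
#include <cstdio>

struct CompressedRun {
  std::int64_t startIndex; // first compressed lane that may be set
  std::int64_t maskIndex;  // number of leading compressed lanes set; this is
                           // what the patch pushes as the new mask dim size
};

// Mirrors `numFrontPadElems / numSrcElemsPerDest` and
// `llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest)`.
static CompressedRun compressConstantRun(std::int64_t origIndex,
                                         std::int64_t numSrcElemsPerDest,
                                         std::int64_t numFrontPadElems) {
  CompressedRun run;
  run.startIndex = numFrontPadElems / numSrcElemsPerDest;
  run.maskIndex = (numFrontPadElems + origIndex + numSrcElemsPerDest - 1) /
                  numSrcElemsPerDest;
  return run;
}

int main() {
  // origIndex = 2 leading set lanes, numSrcElemsPerDest = 2,
  // numFrontPadElems = 1: the set lanes land in compressed lanes
  // [startIndex, maskIndex) = [0, 2), matching [1, 1, 0, 0] above.
  CompressedRun run = compressConstantRun(2, 2, 1);
  std::printf("[%lld, %lld)\n", static_cast<long long>(run.startIndex),
              static_cast<long long>(run.maskIndex)); // prints: [0, 2)
}

As the TODO in the patch notes, only the lanes in [startIndex, maskIndex] need
to be set; the constant-mask path therefore bails out (returns failure()) when
numFrontPadElems != 0 and the mask is multi-dimensional.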