@@ -45,27 +45,40 @@ using namespace mlir;
4545#define DBGSNL () (llvm::dbgs() << " \n " )
4646#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
4747
48- // / Returns a compressed mask. The mask value is set only if any mask is present
49- // / in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
50- // / equals to 1 (intraDataOffset strictly smaller than scale), the following
51- // / mask:
48+ // / Returns a compressed mask for the emulated vector. For example, when
49+ // / emulating an eight-element `i8` vector with `i32` (i.e. when the source
50+ // / elements span two dest elements), this method compresses `vector<8xi1>`
51+ // / into `vector<2xi1>`.
52+ // /
53+ // / The compressed/output mask value is set iff any mask in the corresponding
54+ // / `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
55+ // / `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
56+ // / following mask:
5257// /
5358// / %mask = [1, 1, 0, 0, 0, 0]
5459// /
55- // / will first be padded in the front with number of `intraDataOffset` zeros,
56- // / and pad zeros in the back to make the number of elements a multiple of
57- // / `scale` (just to make it easier to compute). The new mask will be:
60+ // / will first be padded in the front with `numFrontPadElems` zeros, and zeros
61+ // / will be added in the back to make the number of elements a multiple of
62+ // / `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
63+ // /
5864// / %mask = [0, 1, 1, 0, 0, 0, 0, 0]
5965// /
6066// / then it will return the following new compressed mask:
6167// /
6268// / %mask = [1, 1, 0, 0]
69+ // /
70+ // / NOTE: `numFrontPadElems` is assumed to be strictly smaller than
71+ // / `numSrcElemsPerDest`.
6372static FailureOr<Operation *> getCompressedMaskOp (OpBuilder &rewriter,
6473 Location loc, Value mask,
65- int origElements, int scale,
66- int intraDataOffset = 0 ) {
67- assert (intraDataOffset < scale && " intraDataOffset must be less than scale" );
68- auto numElements = llvm::divideCeil (intraDataOffset + origElements, scale);
74+ int numSrcElems,
75+ int numSrcElemsPerDest,
76+ int numFrontPadElems = 0 ) {
77+
78+ assert (numFrontPadElems < numSrcElemsPerDest && " intraDataOffset must be less than scale" );
79+
80+ auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1 ) /
81+ numSrcElemsPerDest;
6982
7083 Operation *maskOp = mask.getDefiningOp ();
7184 SmallVector<vector::ExtractOp, 2 > extractOps;
@@ -93,8 +106,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
93106 size_t numMaskOperands = maskOperands.size ();
94107 AffineExpr s0;
95108 bindSymbols (rewriter.getContext (), s0);
96- s0 = s0 + scale - 1 ;
97- s0 = s0.floorDiv (scale );
109+ s0 = s0 + numSrcElemsPerDest - 1 ;
110+ s0 = s0.floorDiv (numSrcElemsPerDest );
98111 OpFoldResult origIndex =
99112 getAsOpFoldResult (maskOperands[numMaskOperands - 1 ]);
100113 OpFoldResult maskIndex =
@@ -108,18 +121,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
108121 ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
109122 size_t numMaskOperands = maskDimSizes.size ();
110123 int64_t origIndex = maskDimSizes[numMaskOperands - 1 ];
111- int64_t startIndex = intraDataOffset / scale;
112- int64_t maskIndex = llvm::divideCeil (intraDataOffset + origIndex, scale);
124+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
125+ int64_t maskIndex =
126+ llvm::divideCeil (numFrontPadElems + origIndex, numSrcElemsPerDest);
113127
114128 // TODO: we only want the mask between [startIndex, maskIndex] to be true,
115129 // the rest are false.
116- if (intraDataOffset != 0 && maskDimSizes.size () > 1 )
130+ if (numFrontPadElems != 0 && maskDimSizes.size () > 1 )
117131 return failure ();
118132
119133 SmallVector<int64_t > newMaskDimSizes (maskDimSizes.drop_back ());
120134 newMaskDimSizes.push_back (maskIndex);
121135
122- if (intraDataOffset == 0 ) {
136+ if (numFrontPadElems == 0 ) {
123137 newMask = rewriter.create <vector::ConstantMaskOp>(loc, newMaskType,
124138 newMaskDimSizes);
125139 } else {
0 commit comments