@@ -75,83 +75,133 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
-  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+  auto numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                       numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  // TODO: add support for `vector.splat`.
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numDestElems;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask between [startIndex, maskIndex]
+            // to be true, the rest are false.
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numDestElems; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      denseAttr);
+          })
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            // %mask = [false, true, false, true, false, false]
+            //
+            // With front offset of 1, the mask will be padded with 0s in the
+            // front and back so that:
+            // 1. It is aligned with the effective loading bits
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            //    total coverage size is a multiple of bytes). The new mask
+            //    will be like this before compressing:
+            //
+            // %new_mask = [false, false, true, false, true, false, false,
+            //              false]
+            auto denseAttr =
+                cast<DenseIntElementsAttr>(constantOp.getValue());
+            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
+            paddedMaskValues.append(denseAttr.template value_begin<bool>(),
+                                    denseAttr.template value_end<bool>());
+            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < paddedMaskValues.size(); i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= paddedMaskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
+
+  if (!newMask)
+    return failure();
 
   while (!extractOps.empty()) {
     newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
     extractOps.pop_back();
   }
 
-  return newMask;
+  return *newMask;
 }
 
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +235,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
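
The core of the first hunk is the new `arith::ConstantOp` case: a mask defined by an `arith.constant` of `i1` values is padded with `numFrontPadElems` leading zeros, extended to a multiple of `numSrcElemsPerDest`, and then each group of `numSrcElemsPerDest` bits is OR-combined into one bit of the compressed mask. Below is a minimal standalone sketch of that arithmetic in plain C++, detached from the MLIR types used in the patch; the helper name `compressMask` and the use of `std::vector<bool>` are illustrative only, not part of the commit.

```cpp
#include <cstdio>
#include <vector>

// Pad the mask with `numFrontPadElems` leading false values, extend it to a
// multiple of `numSrcElemsPerDest`, then OR-combine each group of
// `numSrcElemsPerDest` values into one element of the compressed mask.
static std::vector<bool> compressMask(const std::vector<bool> &mask,
                                      int numFrontPadElems,
                                      int numSrcElemsPerDest) {
  int numSrcElems = static_cast<int>(mask.size());
  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  padded.resize(numDestElems * numSrcElemsPerDest, false);

  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool combined = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      combined = combined || padded[i + j];
    compressed.push_back(combined);
  }
  return compressed;
}

int main() {
  // The example from the comment in the patch: 6 mask bits, front offset 1,
  // 2 source elements packed per destination element.
  std::vector<bool> mask = {false, true, false, true, false, false};
  for (bool b : compressMask(mask, /*numFrontPadElems=*/1,
                             /*numSrcElemsPerDest=*/2))
    std::printf("%d ", b ? 1 : 0); // prints: 0 1 1 0
  std::printf("\n");
  return 0;
}
```

Run on the mask from the comment above ([false, true, false, true, false, false], front offset 1, two source elements per destination), this yields [false, true, true, false]. OR-combining is deliberately conservative: a destination element must be loaded whenever any of the source elements packed into it is requested.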