@@ -75,83 +75,133 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7575 int numSrcElemsPerDest,
7676 int numFrontPadElems = 0 ) {
7777
78- assert (numFrontPadElems < numSrcElemsPerDest && " intraDataOffset must be less than scale" );
78+ assert (numFrontPadElems < numSrcElemsPerDest &&
79+ " numFrontPadElems must be less than numSrcElemsPerDest" );
7980
80- auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1 ) /
81- numSrcElemsPerDest;
81+ auto numDestElems =
82+ (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1 ) /
83+ numSrcElemsPerDest;
8284
8385 Operation *maskOp = mask.getDefiningOp ();
8486 SmallVector<vector::ExtractOp, 2 > extractOps;
87+ // TODO: add support for `vector.splat`.
8588 // Finding the mask creation operation.
86- while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
89+ while (maskOp &&
90+ !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
91+ maskOp)) {
8792 if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
8893 maskOp = extractOp.getVector ().getDefiningOp ();
8994 extractOps.push_back (extractOp);
9095 }
9196 }
92- auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
93- auto constantMaskOp = dyn_cast_or_null< vector::ConstantMaskOp>(maskOp);
94- if (!createMaskOp && !constantMaskOp )
97+
98+ if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
99+ maskOp) )
95100 return failure ();
96101
97102 // Computing the "compressed" mask. All the emulation logic (i.e. computing
98103 // new mask index) only happens on the last dimension of the vectors.
99- Operation *newMask = nullptr ;
100- SmallVector<int64_t > shape (
104+ SmallVector<int64_t > maskShape (
101105 cast<VectorType>(maskOp->getResultTypes ()[0 ]).getShape ());
102- shape.back () = numElements;
103- auto newMaskType = VectorType::get (shape, rewriter.getI1Type ());
104- if (createMaskOp) {
105- OperandRange maskOperands = createMaskOp.getOperands ();
106- size_t numMaskOperands = maskOperands.size ();
107- AffineExpr s0;
108- bindSymbols (rewriter.getContext (), s0);
109- s0 = s0 + numSrcElemsPerDest - 1 ;
110- s0 = s0.floorDiv (numSrcElemsPerDest);
111- OpFoldResult origIndex =
112- getAsOpFoldResult (maskOperands[numMaskOperands - 1 ]);
113- OpFoldResult maskIndex =
114- affine::makeComposedFoldedAffineApply (rewriter, loc, s0, origIndex);
115- SmallVector<Value> newMaskOperands (maskOperands.drop_back ());
116- newMaskOperands.push_back (
117- getValueOrCreateConstantIndexOp (rewriter, loc, maskIndex));
118- newMask = rewriter.create <vector::CreateMaskOp>(loc, newMaskType,
119- newMaskOperands);
120- } else if (constantMaskOp) {
121- ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
122- size_t numMaskOperands = maskDimSizes.size ();
123- int64_t origIndex = maskDimSizes[numMaskOperands - 1 ];
124- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
125- int64_t maskIndex =
126- llvm::divideCeil (numFrontPadElems + origIndex, numSrcElemsPerDest);
127-
128- // TODO: we only want the mask between [startIndex, maskIndex] to be true,
129- // the rest are false.
130- if (numFrontPadElems != 0 && maskDimSizes.size () > 1 )
131- return failure ();
132-
133- SmallVector<int64_t > newMaskDimSizes (maskDimSizes.drop_back ());
134- newMaskDimSizes.push_back (maskIndex);
135-
136- if (numFrontPadElems == 0 ) {
137- newMask = rewriter.create <vector::ConstantMaskOp>(loc, newMaskType,
138- newMaskDimSizes);
139- } else {
140- SmallVector<bool > newMaskValues;
141- for (int64_t i = 0 ; i < numElements; ++i)
142- newMaskValues.push_back (i >= startIndex && i < maskIndex);
143- auto denseAttr = DenseElementsAttr::get (newMaskType, newMaskValues);
144- newMask = rewriter.create <arith::ConstantOp>(loc, newMaskType, denseAttr);
145- }
146- }
106+ maskShape.back () = numDestElems;
107+ auto newMaskType = VectorType::get (maskShape, rewriter.getI1Type ());
108+ std::optional<Operation *> newMask =
109+ TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
110+ .Case <vector::CreateMaskOp>(
111+ [&](auto createMaskOp) -> std::optional<Operation *> {
112+ OperandRange maskOperands = createMaskOp.getOperands ();
113+ size_t numMaskOperands = maskOperands.size ();
114+ AffineExpr s0;
115+ bindSymbols (rewriter.getContext (), s0);
116+ s0 = s0 + numSrcElemsPerDest - 1 ;
117+ s0 = s0.floorDiv (numSrcElemsPerDest);
118+ OpFoldResult origIndex =
119+ getAsOpFoldResult (maskOperands[numMaskOperands - 1 ]);
120+ OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply (
121+ rewriter, loc, s0, origIndex);
122+ SmallVector<Value> newMaskOperands (maskOperands.drop_back ());
123+ newMaskOperands.push_back (
124+ getValueOrCreateConstantIndexOp (rewriter, loc, maskIndex));
125+ return rewriter.create <vector::CreateMaskOp>(loc, newMaskType,
126+ newMaskOperands);
127+ })
128+ .Case <vector::ConstantMaskOp>([&](auto constantMaskOp)
129+ -> std::optional<Operation *> {
130+ ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
131+ size_t numMaskOperands = maskDimSizes.size ();
132+ int64_t origIndex = maskDimSizes[numMaskOperands - 1 ];
133+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
134+ int64_t maskIndex = llvm::divideCeil (numFrontPadElems + origIndex,
135+ numSrcElemsPerDest);
136+
137+ // TODO: we only want the mask between [startIndex, maskIndex]
138+ // to be true, the rest are false.
139+ if (numFrontPadElems != 0 && maskDimSizes.size () > 1 )
140+ return std::nullopt ;
141+
142+ SmallVector<int64_t > newMaskDimSizes (maskDimSizes.drop_back ());
143+ newMaskDimSizes.push_back (maskIndex);
144+
145+ if (numFrontPadElems == 0 )
146+ return rewriter.create <vector::ConstantMaskOp>(loc, newMaskType,
147+ newMaskDimSizes);
148+
149+ SmallVector<bool > newMaskValues;
150+ for (int64_t i = 0 ; i < numDestElems; ++i)
151+ newMaskValues.push_back (i >= startIndex && i < maskIndex);
152+ auto denseAttr = DenseElementsAttr::get (newMaskType, newMaskValues);
153+ return rewriter.create <arith::ConstantOp>(loc, newMaskType,
154+ denseAttr);
155+ })
156+ .Case <arith::ConstantOp>([&](auto constantOp)
157+ -> std::optional<Operation *> {
158+ // TODO: Support multiple dimensions.
159+ if (maskShape.size () != 1 )
160+ return std::nullopt ;
161+ // Rearrange the original mask values to cover the whole potential
162+ // loading region. For example, in the case of using byte-size for
163+ // emulation, given the following mask:
164+ //
165+ // %mask = [0, 1, 0, 1, 0, 0]
166+ //
167+ // With front offset of 1, the mask will be padded 0s in the front
168+ // and back so that:
169+ // 1. It is aligned with the effective loading bits
170+ // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
171+ // coverage size is a multiple of bytes). The new mask will be like
172+ // this before compressing:
173+ //
174+ // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
175+ auto denseAttr = cast<DenseIntElementsAttr>(constantOp.getValue ());
176+ SmallVector<bool > paddedMaskValues (numFrontPadElems, false );
177+ paddedMaskValues.append (denseAttr.template value_begin <bool >(),
178+ denseAttr.template value_end <bool >());
179+ paddedMaskValues.resize (numDestElems * numSrcElemsPerDest, false );
180+
181+ // Compressing by combining every `numSrcElemsPerDest` elements:
182+ SmallVector<bool > compressedMaskValues;
183+ for (size_t i = 0 ; i < paddedMaskValues.size ();
184+ i += numSrcElemsPerDest) {
185+ bool combinedValue = false ;
186+ for (int j = 0 ; j < numSrcElemsPerDest; ++j) {
187+ combinedValue |= paddedMaskValues[i + j];
188+ }
189+ compressedMaskValues.push_back (combinedValue);
190+ }
191+ return rewriter.create <arith::ConstantOp>(
192+ loc, DenseElementsAttr::get (newMaskType, compressedMaskValues));
193+ });
194+
195+ if (!newMask)
196+ return failure ();
147197
148198 while (!extractOps.empty ()) {
149199 newMask = rewriter.create <vector::ExtractOp>(
150- loc, newMask->getResults ()[0 ], extractOps.back ().getMixedPosition ());
200+ loc, (* newMask) ->getResults ()[0 ], extractOps.back ().getMixedPosition ());
151201 extractOps.pop_back ();
152202 }
153203
154- return newMask;
204+ return * newMask;
155205}
156206
157207// / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +235,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
185235// / `vector.insert_strided_slice`.
186236static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
187237 Value src, Value dest, int64_t offset) {
188- auto srcType = cast<VectorType>(src.getType ());
189- auto destType = cast<VectorType>(dest.getType ());
238+ [[maybe_unused]] auto srcType = cast<VectorType>(src.getType ());
239+ [[maybe_unused]] auto destType = cast<VectorType>(dest.getType ());
190240 assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
191241 " expected source and dest to be vector type" );
192- (void )srcType;
193- (void )destType;
194242 auto offsets = rewriter.getI64ArrayAttr ({offset});
195243 auto strides = rewriter.getI64ArrayAttr ({1 });
196244 return rewriter.create <vector::InsertStridedSliceOp>(loc, dest.getType (), src,
0 commit comments