@@ -38,16 +38,17 @@ using namespace mlir;
3838
3939// / Returns a compressed mask. The mask value is set only if any mask is present
4040// / in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
41- // / equals to 2, the following mask:
41+ // / equals to 1 (intraDataOffset strictly smaller than scale), the following
42+ // / mask:
4243// /
43- // / %mask = [1, 1, 1 , 0, 0, 0]
44+ // / %mask = [1, 1, 0 , 0, 0, 0]
4445// /
4546// / will first be padded with number of `intraDataOffset` zeros:
46- // / %mask = [0, 0 , 1, 1, 1 , 0, 0, 0]
47+ // / %mask = [0, 1 , 1, 0, 0 , 0, 0, 0]
4748// /
4849// / then it will return the following new compressed mask:
4950// /
50- // / %mask = [0 , 1, 1 , 0]
51+ // / %mask = [1 , 1, 0 , 0]
5152static FailureOr<Operation *> getCompressedMaskOp (OpBuilder &rewriter,
5253 Location loc, Value mask,
5354 int origElements, int scale,
@@ -76,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7677 shape.back () = numElements;
7778 auto newMaskType = VectorType::get (shape, rewriter.getI1Type ());
7879 if (createMaskOp) {
79- // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
80- if (intraDataOffset != 0 )
81- return failure ();
8280 OperandRange maskOperands = createMaskOp.getOperands ();
8381 size_t numMaskOperands = maskOperands.size ();
8482 AffineExpr s0;
@@ -130,10 +128,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
130128 return newMask;
131129}
132130
133- // / A wrapper function for emitting `vector.extract_strided_slice`.
131+ // / A wrapper function for emitting `vector.extract_strided_slice`. The vector
132+ // / has to be of 1-D shape.
134133static Value extractSubvectorFrom (RewriterBase &rewriter, Location loc,
135134 VectorType extractType, Value vector,
136135 int64_t frontOffset, int64_t subvecSize) {
136+ // get vector's vector type:
137+ auto vectorType = dyn_cast<VectorType>(vector.getType ());
138+ assert (vectorType && " expected vector type" );
139+ assert (vectorType.getShape ().size () == 1 && " expected 1-D vector type" );
140+ assert (extractType.getShape ().size () == 1 &&
141+ " extractType must be 1-D vector type" );
142+
137143 auto offsets = rewriter.getI64ArrayAttr ({frontOffset});
138144 auto sizes = rewriter.getI64ArrayAttr ({subvecSize});
139145 auto strides = rewriter.getI64ArrayAttr ({1 });
@@ -143,9 +149,17 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
143149 ->getResult (0 );
144150}
145151
146- // / A wrapper function for emitting `vector.insert_strided_slice`.
152+ // / A wrapper function for emitting `vector.insert_strided_slice`. The source
153+ // / and dest vectors must be of 1-D shape.
147154static Value insertSubvectorInto (RewriterBase &rewriter, Location loc,
148155 Value src, Value dest, int64_t offset) {
156+ auto srcType = dyn_cast<VectorType>(src.getType ());
157+ assert (srcType && " expected vector type" );
158+ assert (srcType.getShape ().size () == 1 && " expected 1-D vector type" );
159+ auto destType = dyn_cast<VectorType>(dest.getType ());
160+ assert (destType && " expected vector type" );
161+ assert (destType.getShape ().size () == 1 && " expected 1-D vector type" );
162+
149163 auto offsets = rewriter.getI64ArrayAttr ({offset});
150164 auto strides = rewriter.getI64ArrayAttr ({1 });
151165 return rewriter.create <vector::InsertStridedSliceOp>(loc, dest.getType (), src,
@@ -157,24 +171,20 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
157171// / `srcOffsetVar` is not a constant, making it impossible to use
158172// / vector.extract_strided_slice, as it requires constant offsets.
159173static Value dynamicallyExtractSubVector (RewriterBase &rewriter, Location loc,
160- TypedValue<VectorType> srcVec,
161- Value destVec,
162- OpFoldResult srcOffsetVar,
163- int64_t lengthSubvec) {
164- for (int i = 0 ; i < lengthSubvec; ++i) {
165- Value extractLoc;
166- if (i == 0 ) {
167- extractLoc = srcOffsetVar.dyn_cast <Value>();
168- } else {
169- extractLoc = rewriter.create <arith::AddIOp>(
170- loc, rewriter.getIndexType (), srcOffsetVar.dyn_cast <Value>(),
171- rewriter.create <arith::ConstantIndexOp>(loc, i));
172- }
174+ TypedValue<VectorType> source,
175+ Value dest, OpFoldResult offset,
176+ int64_t numElementsToExtract) {
177+ for (int i = 0 ; i < numElementsToExtract; ++i) {
178+ Value extractLoc =
179+ (i == 0 ) ? offset.dyn_cast <Value>()
180+ : rewriter.create <arith::AddIOp>(
181+ loc, rewriter.getIndexType (), offset.dyn_cast <Value>(),
182+ rewriter.create <arith::ConstantIndexOp>(loc, i));
173183 auto extractOp =
174- rewriter.create <vector::ExtractOp>(loc, srcVec , extractLoc);
175- destVec = rewriter.create <vector::InsertOp>(loc, extractOp, destVec , i);
184+ rewriter.create <vector::ExtractOp>(loc, source , extractLoc);
185+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest , i);
176186 }
177- return destVec ;
187+ return dest ;
178188}
179189
180190// / Load `numLoadedElements` of `newElementType` from `base` at
@@ -183,15 +193,15 @@ static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
183193static TypedValue<VectorType>
184194emulatedVectorLoad (ConversionPatternRewriter &rewriter, Location loc,
185195 Value base, OpFoldResult linearizedIndices,
186- int64_t numLoadedElements , Type oldElememtType,
196+ int64_t numElementsToLoad , Type oldElememtType,
187197 Type newElementType) {
188198 auto scale = newElementType.getIntOrFloatBitWidth () /
189199 oldElememtType.getIntOrFloatBitWidth ();
190200 auto newLoad = rewriter.create <vector::LoadOp>(
191- loc, VectorType::get (numLoadedElements , newElementType), base,
201+ loc, VectorType::get (numElementsToLoad , newElementType), base,
192202 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
193203 return rewriter.create <vector::BitCastOp>(
194- loc, VectorType::get (numLoadedElements * scale, oldElememtType), newLoad);
204+ loc, VectorType::get (numElementsToLoad * scale, oldElememtType), newLoad);
195205};
196206
197207namespace {
0 commit comments