18
18
#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
19
19
#include " mlir/IR/BuiltinAttributes.h"
20
20
#include " mlir/IR/BuiltinTypes.h"
21
+ #include " mlir/IR/OpDefinition.h"
21
22
#include " mlir/IR/TypeUtilities.h"
22
23
#include " mlir/IR/Value.h"
23
24
#include " mlir/Transforms/DialectConversion.h"
@@ -37,16 +38,17 @@ using namespace mlir;
37
38
38
39
// / Returns a compressed mask. The mask value is set only if any mask is present
39
40
// / in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
40
- // / equals to 2, the following mask:
41
+ // / equals to 1 (intraDataOffset strictly smaller than scale), the following
42
+ // / mask:
41
43
// /
42
- // / %mask = [1, 1, 1 , 0, 0, 0]
44
+ // / %mask = [1, 1, 0 , 0, 0, 0]
43
45
// /
44
46
// / will first be padded in the front with `intraDataOffset` zeros:
45
- // / %mask = [0, 0 , 1, 1, 1 , 0, 0, 0]
47
+ // / %mask = [0, 1 , 1, 0, 0 , 0, 0, 0]
46
48
// /
47
49
// / then it will return the following new compressed mask:
48
50
// /
49
- // / %mask = [0 , 1, 1 , 0]
51
+ // / %mask = [1 , 1, 0 , 0]
50
52
static FailureOr<Operation *> getCompressedMaskOp (OpBuilder &rewriter,
51
53
Location loc, Value mask,
52
54
int origElements, int scale,
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
75
77
shape.back () = numElements;
76
78
auto newMaskType = VectorType::get (shape, rewriter.getI1Type ());
77
79
if (createMaskOp) {
78
- // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
79
- if (intraDataOffset != 0 )
80
- return failure ();
81
80
OperandRange maskOperands = createMaskOp.getOperands ();
82
81
size_t numMaskOperands = maskOperands.size ();
83
82
AffineExpr s0;
@@ -129,26 +128,79 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
129
128
return newMask;
130
129
}
131
130
132
- static Value extractSubvectorFrom (RewriterBase &rewriter, Location loc,
133
- VectorType extractType, Value vector,
134
- int64_t frontOffset, int64_t subvecSize) {
131
+ // / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
132
+ // / emitting `vector.extract_strided_slice`.
133
+ static Value staticallyExtractSubvector (OpBuilder &rewriter, Location loc,
134
+ VectorType extractType, Value source,
135
+ int64_t frontOffset,
136
+ int64_t subvecSize) {
137
+ auto vectorType = cast<VectorType>(source.getType ());
138
+ assert ((vectorType.getRank () == 1 && extractType.getRank () == 1 ) &&
139
+ " expected 1-D source and destination types" );
135
140
auto offsets = rewriter.getI64ArrayAttr ({frontOffset});
136
141
auto sizes = rewriter.getI64ArrayAttr ({subvecSize});
137
142
auto strides = rewriter.getI64ArrayAttr ({1 });
138
143
return rewriter
139
- .create <vector::ExtractStridedSliceOp>(loc, extractType, vector , offsets,
144
+ .create <vector::ExtractStridedSliceOp>(loc, extractType, source , offsets,
140
145
sizes, strides)
141
146
->getResult (0 );
142
147
}
143
148
144
- static Value insertSubvectorInto (RewriterBase &rewriter, Location loc,
145
- Value src, Value dest, int64_t offset) {
149
+ // / Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
150
+ // / at `offset`. it is a wrapper function for emitting
151
+ // / `vector.insert_strided_slice`.
152
+ static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
153
+ Value src, Value dest, int64_t offset) {
154
+ auto srcType = cast<VectorType>(src.getType ());
155
+ auto destType = cast<VectorType>(dest.getType ());
156
+ assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
157
+ " expected source and dest to be vector type" );
146
158
auto offsets = rewriter.getI64ArrayAttr ({offset});
147
159
auto strides = rewriter.getI64ArrayAttr ({1 });
148
160
return rewriter.create <vector::InsertStridedSliceOp>(loc, dest.getType (), src,
149
161
dest, offsets, strides);
150
162
}
151
163
164
+ // / Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
165
+ // / and size `numElementsToExtract`, and inserts into the `dest` vector. This
166
+ // / function emits multiple `vector.extract` and `vector.insert` ops, so only
167
+ // / use it when `offset` cannot be folded into a constant value.
168
+ static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
169
+ TypedValue<VectorType> source,
170
+ Value dest, OpFoldResult offset,
171
+ int64_t numElementsToExtract) {
172
+ for (int i = 0 ; i < numElementsToExtract; ++i) {
173
+ Value extractLoc =
174
+ (i == 0 ) ? offset.dyn_cast <Value>()
175
+ : rewriter.create <arith::AddIOp>(
176
+ loc, rewriter.getIndexType (), offset.dyn_cast <Value>(),
177
+ rewriter.create <arith::ConstantIndexOp>(loc, i));
178
+ auto extractOp =
179
+ rewriter.create <vector::ExtractOp>(loc, source, extractLoc);
180
+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, i);
181
+ }
182
+ return dest;
183
+ }
184
+
185
+ // / Returns the op sequence for an emulated sub-byte data type vector load.
186
+ // / specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187
+ // / The load location is given by `base` and `linearizedIndices`, and the
188
+ // / load size is given by `numEmulatedElementsToLoad`.
189
+ static TypedValue<VectorType>
190
+ emulatedVectorLoad (OpBuilder &rewriter, Location loc, Value base,
191
+ OpFoldResult linearizedIndices,
192
+ int64_t numEmultedElementsToLoad, Type origElemType,
193
+ Type emulatedElemType) {
194
+ auto scale = emulatedElemType.getIntOrFloatBitWidth () /
195
+ origElemType.getIntOrFloatBitWidth ();
196
+ auto newLoad = rewriter.create <vector::LoadOp>(
197
+ loc, VectorType::get (numEmultedElementsToLoad, emulatedElemType), base,
198
+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
199
+ return rewriter.create <vector::BitCastOp>(
200
+ loc, VectorType::get (numEmultedElementsToLoad * scale, origElemType),
201
+ newLoad);
202
+ };
203
+
152
204
namespace {
153
205
154
206
// ===----------------------------------------------------------------------===//
@@ -380,25 +432,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380
432
? getConstantIntValue (linearizedInfo.intraDataOffset )
381
433
: 0 ;
382
434
383
- if (!foldedIntraVectorOffset) {
384
- // unimplemented case for dynamic intra vector offset
385
- return failure ();
386
- }
387
-
435
+ // Always load enough elements which can cover the original elements.
436
+ int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
388
437
auto numElements =
389
- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
390
- auto newLoad = rewriter.create <vector::LoadOp>(
391
- loc, VectorType::get (numElements, newElementType), adaptor.getBase (),
392
- getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
393
-
394
- Value result = rewriter.create <vector::BitCastOp>(
395
- loc, VectorType::get (numElements * scale, oldElementType), newLoad);
396
-
397
- if (isUnalignedEmulation) {
398
- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
399
- *foldedIntraVectorOffset, origElements);
438
+ llvm::divideCeil (maxintraDataOffset + origElements, scale);
439
+ Value result =
440
+ emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
441
+ numElements, oldElementType, newElementType);
442
+
443
+ if (foldedIntraVectorOffset) {
444
+ if (isUnalignedEmulation) {
445
+ result =
446
+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
447
+ *foldedIntraVectorOffset, origElements);
448
+ }
449
+ } else {
450
+ auto resultVector = rewriter.create <arith::ConstantOp>(
451
+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
452
+ result = dynamicallyExtractSubVector (
453
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
454
+ linearizedInfo.intraDataOffset , origElements);
400
455
}
401
-
402
456
rewriter.replaceOp (op, result);
403
457
return success ();
404
458
}
@@ -513,8 +567,8 @@ struct ConvertVectorMaskedLoad final
513
567
// create an empty vector of the new type
514
568
auto emptyVector = rewriter.create <arith::ConstantOp>(
515
569
loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
516
- passthru = insertSubvectorInto (rewriter, loc, passthru, emptyVector,
517
- *foldedIntraVectorOffset);
570
+ passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
571
+ *foldedIntraVectorOffset);
518
572
}
519
573
auto newPassThru =
520
574
rewriter.create <vector::BitCastOp>(loc, loadType, passthru);
@@ -537,16 +591,17 @@ struct ConvertVectorMaskedLoad final
537
591
// TODO: can fold if op's mask is constant
538
592
auto emptyVector = rewriter.create <arith::ConstantOp>(
539
593
loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
540
- mask = insertSubvectorInto (rewriter, loc, op.getMask (), emptyVector,
541
- *foldedIntraVectorOffset);
594
+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyVector,
595
+ *foldedIntraVectorOffset);
542
596
}
543
597
544
598
Value result =
545
599
rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
546
600
547
601
if (isUnalignedEmulation) {
548
- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
549
- *foldedIntraVectorOffset, origElements);
602
+ result =
603
+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
604
+ *foldedIntraVectorOffset, origElements);
550
605
}
551
606
rewriter.replaceOp (op, result);
552
607
@@ -604,13 +659,10 @@ struct ConvertVectorTransferRead final
604
659
? getConstantIntValue (linearizedInfo.intraDataOffset )
605
660
: 0 ;
606
661
607
- if (!foldedIntraVectorOffset) {
608
- // unimplemented case for dynamic inra-vector offset
609
- return failure ();
610
- }
611
-
662
+ auto maxIntraVectorOffset =
663
+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
612
664
auto numElements =
613
- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
665
+ llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
614
666
615
667
auto newRead = rewriter.create <vector::TransferReadOp>(
616
668
loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
@@ -621,9 +673,18 @@ struct ConvertVectorTransferRead final
621
673
loc, VectorType::get (numElements * scale, oldElementType), newRead);
622
674
623
675
Value result = bitCast->getResult (0 );
624
- if (isUnalignedEmulation) {
625
- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
626
- *foldedIntraVectorOffset, origElements);
676
+ if (foldedIntraVectorOffset) {
677
+ if (isUnalignedEmulation) {
678
+ result =
679
+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
680
+ *foldedIntraVectorOffset, origElements);
681
+ }
682
+ } else {
683
+ auto zeros = rewriter.create <arith::ConstantOp>(
684
+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
685
+ result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
686
+ linearizedInfo.intraDataOffset ,
687
+ origElements);
627
688
}
628
689
rewriter.replaceOp (op, result);
629
690
0 commit comments