Skip to content

Commit 461b969

Browse files
committed
Split up the AND operation over the cycles
- Allows each shift to be computed only on the cycle that requires it, rather than on the entire input.
- Removes the initial vector.shuffle, since we can operate directly on the input and merge the per-cycle results afterwards.
1 parent e29a6da commit 461b969

File tree

2 files changed

+44
-63
lines changed

2 files changed

+44
-63
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -837,23 +837,14 @@ static bool isCyclic(SmallVector<T> xs, int64_t cycleLen) {
837837
return true;
838838
}
839839

840-
static SmallVector<int64_t> constructShuffles(int64_t inputSize,
841-
int64_t numCycles,
840+
static SmallVector<int64_t> constructShuffles(int64_t numCycles,
842841
int64_t cycleLen, int64_t idx) {
843-
// If idx == 1, then the first operand of the shuffle will be the mask which
844-
// will have the original size. So we need to step through the mask with a
845-
// stride of cycleSize.
846-
// When idx > 1, then the first operand will be the size of (idx * cycleSize)
847-
// and so we take the first idx elements of the input and then append the
848-
// strided mask value.
849-
int64_t inputStride = idx == 1 ? cycleLen : idx;
850-
851842
SmallVector<int64_t> shuffles;
852843
for (int64_t cycle = 0; cycle < numCycles; cycle++) {
853844
for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
854-
shuffles.push_back(cycle * inputStride + inputIdx);
845+
shuffles.push_back(cycle * idx + inputIdx);
855846
}
856-
shuffles.push_back(inputSize + cycle * cycleLen + idx);
847+
shuffles.push_back(numCycles * idx + cycle);
857848
}
858849
return shuffles;
859850
}
@@ -917,47 +908,31 @@ Value BitCastRewriter::splatRewriteStep(
917908
PatternRewriter &rewriter, Location loc, Value initialValue,
918909
Value runningResult, const BitCastRewriter::Metadata &metadata) {
919910

920-
// Initial result will be the Shifted Mask which will have the shuffles size.
921-
int64_t inputSize = metadata.shuffles.size();
922-
int64_t numCycles = inputSize / cycleLen;
923-
924-
auto shuffleOp = rewriter.create<vector::ShuffleOp>(
925-
loc, initialValue, initialValue, metadata.shuffles);
926-
927-
// Intersect with the mask.
928-
VectorType shuffledVectorType = shuffleOp.getResultVectorType();
929-
auto constOp = rewriter.create<arith::ConstantOp>(
930-
loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
931-
Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
932-
911+
int64_t numCycles = metadata.shuffles.size() / cycleLen;
912+
ShapedType vectorType = dyn_cast<ShapedType>(initialValue.getType());
933913
Value result;
934914
for (int64_t idx = 0; idx < cycleLen; idx++) {
915+
// Intersect with the mask.
916+
auto constOp = rewriter.create<arith::ConstantOp>(
917+
loc, DenseElementsAttr::get(vectorType, metadata.masks[idx]));
918+
Value andValue = rewriter.create<arith::AndIOp>(loc, initialValue, constOp);
919+
935920
auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
936-
loc, SplatElementsAttr::get(shuffledVectorType,
937-
metadata.shiftRightAmounts[idx]));
921+
loc,
922+
SplatElementsAttr::get(vectorType, metadata.shiftRightAmounts[idx]));
938923
Value shiftedRight =
939924
rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
940925

941926
auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
942-
loc, SplatElementsAttr::get(shuffledVectorType,
943-
metadata.shiftLeftAmounts[idx]));
927+
loc,
928+
SplatElementsAttr::get(vectorType, metadata.shiftLeftAmounts[idx]));
944929
Value shiftedLeft =
945930
rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
946931

947-
if (result) {
948-
SmallVector<int64_t> shuffles =
949-
constructShuffles(inputSize, numCycles, cycleLen, idx);
950-
result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
951-
shuffles);
952-
953-
// After the first shuffle in the chain, the size of the input result will
954-
// grow as we append more shuffles together to reconstruct the
955-
// shuffledVectorType size. Each iteration they will retain numCycles more
956-
// elements.
957-
inputSize = numCycles * (idx + 1);
958-
} else {
959-
result = shiftedLeft;
960-
}
932+
SmallVector<int64_t> shuffles = constructShuffles(numCycles, cycleLen, idx);
933+
result = result ? rewriter.create<vector::ShuffleOp>(loc, result,
934+
shiftedLeft, shuffles)
935+
: shiftedLeft;
961936
}
962937

963938
return result;

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,15 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
230230
}
231231

232232
// CHECK-LABEL: func.func @fext_splat1(
233-
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
233+
// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
234234
func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
235-
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
236-
// CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
237-
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
238-
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
239-
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
240-
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
235+
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<2xi8>
236+
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<-16> : vector<2xi8>
237+
// CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<2xi8>
238+
// CHECK-DAG: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<2xi8>
239+
// CHECK-DAG: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<2xi8>
240+
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST]] : vector<2xi8>
241+
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>
241242
// CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
242243
// CHECK: return %[[RES]] : vector<4xi16>
243244
%0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
@@ -246,20 +247,25 @@ func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
246247
}
247248

248249
// CHECK-LABEL: func.func @fext_splat2(
249-
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
250+
// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
250251
func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
251-
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
252-
// CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
253-
// CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
254-
// CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
255-
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
256-
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
257-
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
258-
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
259-
// CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
260-
// CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
261-
// CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
262-
// CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
252+
// CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<3xi16>
253+
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<240> : vector<3xi16>
254+
// CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<3840> : vector<3xi16>
255+
// CHECK-DAG: %[[MASK3:.*]] = arith.constant dense<-4096> : vector<3xi16>
256+
// CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<3xi16>
257+
// CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<3xi16>
258+
// CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<3xi16>
259+
// CHECK: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<3xi16>
260+
// CHECK: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<3xi16>
261+
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST0]] : vector<3xi16>
262+
// CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 3, 1, 4, 2, 5] : vector<3xi16>, vector<3xi16>
263+
// CHECK: %[[A2:.*]] = arith.andi %[[ARG]], %[[MASK2]] : vector<3xi16>
264+
// CHECK: %[[SHR1:.*]] = arith.shrui %[[A2]], %[[SHR_CST1]] : vector<3xi16>
265+
// CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 6, 2, 3, 7, 4, 5, 8] : vector<6xi16>, vector<3xi16>
266+
// CHECK: %[[A3:.*]] = arith.andi %[[ARG]], %[[MASK3]] : vector<3xi16>
267+
// CHECK: %[[SHR2:.*]] = arith.shrui %[[A3]], %[[SHR_CST2]] : vector<3xi16>
268+
// CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 9, 3, 4, 5, 10, 6, 7, 8, 11] : vector<9xi16>, vector<3xi16>
263269
// CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
264270
// CHECK: return %[[RES]] : vector<12xi32>
265271
%0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>

0 commit comments

Comments
 (0)