diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..81d66e7b6ab17 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -318,6 +318,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<
   }];
 
+  let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$max_cycle_len);
+
   let assemblyFormat = "attr-dict";
 }
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1..9cd3d80c441d9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -376,7 +376,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
 /// ops over wider types.
 /// Warning: these patterns currently only work for little endian targets.
+/// `maxCycleLen` caps the cycle length for which cyclic per-lane shift
+/// amounts are rewritten into splat-constant shifts; 0 disables this rewrite.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
-                                             PatternBenefit benefit = 1);
+                                             PatternBenefit benefit = 1,
+                                             unsigned maxCycleLen = 0);
 
 /// Appends patterns for emulating a sub-byte vector transpose.
 void populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..e8c9033da7268 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -166,7 +166,8 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
 
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  populateVectorNarrowTypeRewritePatterns(patterns);
+  populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1,
+                                          getMaxCycleLen());
   populateVectorTransposeNarrowTypeRewritePatterns(patterns);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c8..425431b7fa343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
 ///   [0] = {0, [0, 10)}, {1, [0, 5)}
 ///   [1] = {1, [5, 10)}, {2, [0, 10)}
+/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
+///   [0] = {0, [0, 4)}, {1, [0, 4)}
+///   [1] = {2, [0, 4)}, {3, [0, 4)}
+/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
+///   [0] = {0, [0, 4)}
+///   [1] = {0, [4, 8)}
+///   [2] = {1, [0, 4)}
+///   [3] = {1, [4, 8)}
 struct BitCastBitsEnumerator {
   BitCastBitsEnumerator(VectorType sourceVectorType,
                         VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+///
+/// The above algorithm relies on per-lane (non-uniform) shift amounts for
+/// the left and right shifts, which can be inefficient on certain targets
+/// (e.g. RDNA GPUs) in contrast to shifting by a splat constant value.
+/// So, when possible, we rewrite the bitcast as a combination of shifts
+/// using a constant splat value, and then regroup the selected terms via
+/// shuffles.
+///
+/// E.g., instead of:
+///   res = arith.shrui x [0, 4, 8, 0, 4, 8]
+/// use:
+///   y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
+///   y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
+///   y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
+///   y3 = vector.shuffle y y1 [0, 7, 3, 10]
+///   res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
+///
+/// This is possible when the precomputed shift amounts follow a cyclic
+/// pattern [x, y, z, ..., x, y, z, ...] whose cycle length, cycleLen,
+/// satisfies 1 < cycleLen < size(shiftAmounts), and the shuffles are of the
+/// form [0, 0, 0, ..., 1, 1, 1, ...]. This is a common pattern in
+/// (de)quantization, e.g. i24 -> 3xi8 or 3xi8 -> i24. The modified
+/// algorithm shares step 1 with the above and then proceeds as follows:
+///
+/// 2. for each element x in the cycle of the rightShiftAmounts, create a
+///    shrui with a splat constant of x.
+/// 3. repeat 2. with the respective leftShiftAmounts.
+/// 4. construct a chain of vector.shuffles that reconstructs the result
+///    from the chained shifts.
 struct BitCastRewriter {
   /// Helper metadata struct to hold the static quantities for the rewrite.
   struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
                            Value initialValue, Value runningResult,
                            const BitCastRewriter::Metadata &metadata);
 
+  /// Rewrite one step of the sequence when a splat constant can be used for
+  /// the shiftright and shiftleft.
+  Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
+                         Value initialValue, Value runningResult,
+                         const BitCastRewriter::Metadata &metadata);
+
+  /// Whether the splat rewrite applies: 1 < cycleLen <= maxCycleLen.
+  bool useSplatStep(unsigned maxCycleLen) const {
+    return 1 < cycleLen && cycleLen <= maxCycleLen;
+  }
+
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
   /// element of the result vector.
   BitCastBitsEnumerator enumerator;
+
+  /// Cycle length computed by precomputeMetadata(); a value > 1 means the
+  /// precomputed shift amounts are cyclic and splatRewriteStep can be used.
+  int64_t cycleLen = 0;
 };
 
 } // namespace
@@ -775,8 +827,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   return success();
 }
 
+// Check whether the vector repeats its first cycleLen elements.
+template <typename T>
+static bool isCyclic(const SmallVector<T> &xs, int64_t cycleLen) {
+  for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++) {
+    if (xs[idx] != xs[idx % cycleLen])
+      return false;
+  }
+  return true;
+}
+
+static SmallVector<int64_t> constructShuffles(int64_t numCycles,
+                                              int64_t cycleLen, int64_t idx) {
+  SmallVector<int64_t> shuffles;
+  for (int64_t cycle = 0; cycle < numCycles; cycle++) {
+    for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
+      shuffles.push_back(cycle * idx + inputIdx);
+    }
+    shuffles.push_back(numCycles * idx + cycle);
+  }
+  return shuffles;
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  bool cyclicShifts = true;
   SmallVector<Metadata> result;
   for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
        shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +886,55 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
           IntegerAttr::get(shuffledElementType, shiftLeft));
     }
 
+    // Compute a potential cycle length by counting how many of the leading
+    // shuffle entries equal the first one.
+    cycleLen = 1;
+    for (int64_t n = shuffles.size(); cycleLen < n; cycleLen++)
+      if (shuffles[cycleLen] != shuffles[0])
+        break;
+
+    cyclicShifts = cyclicShifts && (cycleLen < (int64_t)shuffles.size()) &&
+                   isCyclic(shiftRightAmounts, cycleLen) &&
+                   isCyclic(shiftLeftAmounts, cycleLen);
+
     result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
   }
+
+  cycleLen = cyclicShifts ? cycleLen : 0;
   return result;
 }
+
+Value BitCastRewriter::splatRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
+
+  int64_t numCycles = metadata.shuffles.size() / cycleLen;
+  ShapedType vectorType = cast<ShapedType>(initialValue.getType());
+  Value result;
+  for (int64_t idx = 0; idx < cycleLen; idx++) {
+    // Intersect with the mask.
+    auto constOp = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(vectorType, metadata.masks[idx]));
+    Value andValue =
+        rewriter.create<arith::AndIOp>(loc, initialValue, constOp);
+
+    auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+        loc,
+        SplatElementsAttr::get(vectorType, metadata.shiftRightAmounts[idx]));
+    Value shiftedRight =
+        rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+    auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+        loc,
+        SplatElementsAttr::get(vectorType, metadata.shiftLeftAmounts[idx]));
+    Value shiftedLeft =
+        rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+    SmallVector<int64_t> shuffles = constructShuffles(numCycles, cycleLen, idx);
+    result = result ? rewriter.create<vector::ShuffleOp>(loc, result,
+                                                         shiftedLeft, shuffles)
+                    : shiftedLeft;
+  }
+  return result;
+}
@@ -939,6 +1061,11 @@ namespace {
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  RewriteBitCastOfTruncI(MLIRContext *context, PatternBenefit benefit,
+                         unsigned maxCycleLen)
+      : OpRewritePattern<vector::BitCastOp>(context, benefit),
+        maxCycleLen{maxCycleLen} {}
+
   LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                 PatternRewriter &rewriter) const override {
     // The source must be a trunc op.
@@ -961,8 +1088,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep(maxCycleLen)
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -986,6 +1117,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     return success();
   }
+
+private:
+  unsigned maxCycleLen;
 };
 
 } // namespace
@@ -1001,8 +1135,10 @@ template <typename ExtOpType>
 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
   using OpRewritePattern<ExtOpType>::OpRewritePattern;
 
-  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
-      : OpRewritePattern<ExtOpType>(context, benefit) {}
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit,
+                      unsigned maxCycleLen)
+      : OpRewritePattern<ExtOpType>(context, benefit),
+        maxCycleLen{maxCycleLen} {}
 
   LogicalResult matchAndRewrite(ExtOpType extOp,
                                 PatternRewriter &rewriter) const override {
@@ -1026,8 +1162,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.genericRewriteStep(
-          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
+      runningResult =
+          bcr.useSplatStep(maxCycleLen)
+              ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
+                                     runningResult, metadata)
+              : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+                                       sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -1044,6 +1184,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
+
+private:
+  unsigned maxCycleLen;
 };
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1365,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen) {
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtSIOp>,
                RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
-                                                    benefit);
+                                                    benefit, maxCycleLen);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
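The splat rewrite above hinges on two small helpers, `isCyclic` and `constructShuffles`. The following standalone sketch (illustrative only, not part of the patch; plain C++ with `std::vector` standing in for `SmallVector`) mirrors their logic and prints the shuffle masks for the `vector<3xi16> -> vector<12xi4>` case, which can be compared against the `@fext_splat2` CHECK lines in the tests below:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors isCyclic above: true iff xs is a repetition of its first
// cycleLen elements.
static bool isCyclic(const std::vector<int64_t> &xs, int64_t cycleLen) {
  for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++)
    if (xs[idx] != xs[idx % cycleLen])
      return false;
  return true;
}

// Mirrors constructShuffles above: after idx steps the accumulated result
// holds idx elements per cycle; the shuffle appends one lane of the next
// shifted vector (one lane per cycle) to each cycle. cycleLen is kept for
// signature parity with the patch; the computation does not need it.
static std::vector<int64_t> constructShuffles(int64_t numCycles,
                                              int64_t /*cycleLen*/,
                                              int64_t idx) {
  std::vector<int64_t> shuffles;
  for (int64_t cycle = 0; cycle < numCycles; cycle++) {
    for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++)
      shuffles.push_back(cycle * idx + inputIdx);
    shuffles.push_back(numCycles * idx + cycle);
  }
  return shuffles;
}

int main() {
  // vector<3xi16> -> vector<12xi4>: the right-shift amounts cycle with
  // length 4 over 3 cycles (one cycle per source element).
  std::vector<int64_t> shiftRight = {0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12};
  const int64_t cycleLen = 4, numCycles = 3;
  std::cout << "cyclic: " << isCyclic(shiftRight, cycleLen) << "\n";
  for (int64_t idx = 1; idx < cycleLen; idx++) {
    std::cout << "idx " << idx << ":";
    for (int64_t s : constructShuffles(numCycles, cycleLen, idx))
      std::cout << ' ' << s;
    std::cout << "\n";
  }
  // Prints:
  //   idx 1: 0 3 1 4 2 5
  //   idx 2: 0 1 6 2 3 7 4 5 8
  //   idx 3: 0 1 2 9 3 4 5 10 6 7 8 11
  // matching the vector.shuffle masks in the @fext_splat2 CHECK lines.
}
```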
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c..396a9e9ee2cb5 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,42 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
   return %1 : vector<8xi6>
 }
 
+// CHECK-LABEL: func.func @ftrunc_splat1(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi16>) -> vector<1xi8> {
+func.func @ftrunc_splat1(%a: vector<2xi16>) -> vector<1xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<1xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<1xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<1xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<1xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<1xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<1xi16> to vector<1xi8>
+  // CHECK: return %[[RES]] : vector<1xi8>
+  %0 = arith.trunci %a : vector<2xi16> to vector<2xi4>
+  %1 = vector.bitcast %0 : vector<2xi4> to vector<1xi8>
+  return %1 : vector<1xi8>
+}
+
+// CHECK-LABEL: func.func @ftrunc_splat2(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<2xi8> {
+func.func @ftrunc_splat2(%a: vector<4xi16>) -> vector<2xi8> {
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<2xi16>
+  // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<2xi16>
+  // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 2] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 3] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<2xi16>
+  // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<2xi16>
+  // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<2xi16>
+  // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<2xi16> to vector<2xi8>
+  // CHECK: return %[[RES]] : vector<2xi8>
+  %0 = arith.trunci %a : vector<4xi16> to vector<4xi4>
+  %1 = vector.bitcast %0 : vector<4xi4> to vector<2xi8>
+  return %1 : vector<2xi8>
+}
+
 // CHECK-LABEL: func.func @f1ext(
 // CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
 func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
@@ -193,6 +229,50 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @fext_splat1(
+// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<2xi8>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<-16> : vector<2xi8>
+  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<2xi8>
+  // CHECK-DAG: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<2xi8>
+  // CHECK-DAG: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<2xi8>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST]] : vector<2xi8>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
+  // CHECK: return %[[RES]] : vector<4xi16>
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
+  return %1 : vector<4xi16>
+}
+
+// CHECK-LABEL: func.func @fext_splat2(
+// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
+  // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<3xi16>
+  // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<240> : vector<3xi16>
+  // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<3840> : vector<3xi16>
+  // CHECK-DAG: %[[MASK3:.*]] = arith.constant dense<-4096> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<3xi16>
+  // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<3xi16>
+  // CHECK: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<3xi16>
+  // CHECK: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<3xi16>
+  // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST0]] : vector<3xi16>
+  // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 3, 1, 4, 2, 5] : vector<3xi16>, vector<3xi16>
+  // CHECK: %[[A2:.*]] = arith.andi %[[ARG]], %[[MASK2]] : vector<3xi16>
+  // CHECK: %[[SHR1:.*]] = arith.shrui %[[A2]], %[[SHR_CST1]] : vector<3xi16>
+  // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 6, 2, 3, 7, 4, 5, 8] : vector<6xi16>, vector<3xi16>
+  // CHECK: %[[A3:.*]] = arith.andi %[[ARG]], %[[MASK3]] : vector<3xi16>
+  // CHECK: %[[SHR2:.*]] = arith.shrui %[[A3]], %[[SHR_CST2]] : vector<3xi16>
+  // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 9, 3, 4, 5, 10, 6, 7, 8, 11] : vector<9xi16>, vector<3xi16>
+  // CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
+  // CHECK: return %[[RES]] : vector<12xi32>
+  %0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
+  %1 = arith.extui %0 : vector<12xi4> to vector<12xi32>
+  return %1 : vector<12xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -330,7 +410,7 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %f {
-      transform.apply_patterns.vector.rewrite_narrow_types
+      transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f43..a7e13ea1a79c4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,36 @@ func.func @f3(%v: vector<2xi48>) {
   return
 }
 
+func.func @print_as_i1_2xi8(%v : vector<2xi8>) {
+  %bitsi16 = vector.bitcast %v : vector<2xi8> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+func.func @print_as_i1_4xi4(%v : vector<4xi4>) {
+  %bitsi16 = vector.bitcast %v : vector<4xi4> to vector<16xi1>
+  vector.print %bitsi16 : vector<16xi1>
+  return
+}
+
+func.func @ftrunc_splat(%v: vector<2xi24>) {
+  %trunc = arith.trunci %v : vector<2xi24> to vector<2xi8>
+  func.call @print_as_i1_2xi8(%trunc) : (vector<2xi8>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 1, 1 )
+
+  %bitcast = vector.bitcast %trunc : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%bitcast) : (vector<4xi4>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  return
+}
+
 func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
   %bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
   vector.print %bitsi40 : vector<40xi1>
@@ -164,6 +194,32 @@ func.func @fext(%a: vector<5xi8>) {
   return
 }
 
+func.func @print_as_i1_4xi8(%v : vector<4xi8>) {
+  %bitsi32 = vector.bitcast %v : vector<4xi8> to vector<32xi1>
+  vector.print %bitsi32 : vector<32xi1>
+  return
+}
+
+func.func @fext_splat(%a: vector<2xi8>) {
+  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+  func.call @print_as_i1_4xi4(%0) : (vector<4xi4>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 1, 1,
+  // CHECK-SAME: 1, 1, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1 )
+
+  %1 = arith.extui %0 : vector<4xi4> to vector<4xi8>
+  func.call @print_as_i1_4xi8(%1) : (vector<4xi8>) -> ()
+  // CHECK: (
+  // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0,
+  // CHECK-SAME: 1, 1, 0, 0, 0, 0, 0, 0,
+  // CHECK-SAME: 0, 0, 1, 1, 0, 0, 0, 0 )
+
+  return
+}
+
 func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
   %c0 = arith.constant 0: index
   %mask = vector.constant_mask [3] : vector<6xi1>
@@ -190,9 +246,19 @@ func.func @entry() {
   func.call @f3(%v3) : (vector<2xi48>) -> ()
 
   %v4 = arith.constant dense<[
+    0xafe, 0xbc3
+  ]> : vector<2xi24>
+  func.call @ftrunc_splat(%v4) : (vector<2xi24>) -> ()
+
+  %v5 = arith.constant dense<[
     0xef, 0xee, 0xed, 0xec, 0xeb
   ]> : vector<5xi8>
-  func.call @fext(%v4) : (vector<5xi8>) -> ()
+  func.call @fext(%v5) : (vector<5xi8>) -> ()
+
+  %v6 = arith.constant dense<[
+    0xfe, 0xc3
+  ]> : vector<2xi8>
+  func.call @fext_splat(%v6) : (vector<2xi8>) -> ()
 
   // Set up memory.
   %c0 = arith.constant 0: index
@@ -218,7 +284,7 @@ module attributes {transform.with_named_sequence} {
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %f {
-      transform.apply_patterns.vector.rewrite_narrow_types
+      transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
    } : !transform.any_op
     transform.yield
  }
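As a quick cross-check of the expected `vector.print` output asserted by the integration test, this standalone C++ snippet (illustrative only, not part of the patch) reproduces the little-endian i1 expansions for the constants fed to `@ftrunc_splat` and `@fext_splat`:

```cpp
#include <cstdint>
#include <iostream>

// Print `bits` bits of `value`, least significant bit first, which is how a
// vector.bitcast to vector<Nxi1> lays out bits on a little-endian target.
static void printBitsLSBFirst(uint32_t value, int bits) {
  for (int i = 0; i < bits; i++)
    std::cout << ((value >> i) & 1) << (i + 1 < bits ? ", " : "\n");
}

int main() {
  // @ftrunc_splat: arith.trunci [0xafe, 0xbc3] : vector<2xi24> -> vector<2xi8>
  // keeps the low byte of each element.
  for (uint32_t v : {0xafeu, 0xbc3u})
    printBitsLSBFirst(v & 0xff, 8); // 0,1,1,1,1,1,1,1 then 1,1,0,0,0,0,1,1

  // @fext_splat: vector.bitcast [0xfe, 0xc3] : vector<2xi8> -> vector<4xi4>
  // splits each byte into low/high nibbles; arith.extui then zero-extends
  // each nibble to i8.
  for (uint32_t v : {0xfeu, 0xc3u}) {
    printBitsLSBFirst(v & 0xf, 8);        // low nibble, zero-extended
    printBitsLSBFirst((v >> 4) & 0xf, 8); // high nibble, zero-extended
  }
}
```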