diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index d04f302200519..a674a59009181 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth(); unsigned dstElemBitwidth = dstType.getElementTypeBitWidth(); - // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now. - if (srcElemBitwidth != 4 || dstElemBitwidth < 8 || - (dstElemBitwidth % srcElemBitwidth) != 0) - return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); + if (dstElemBitwidth < 8) + return rewriter.notifyMatchFailure( + op, "the bitwidth of dstType must be greater than or equal to 8"); + if (dstElemBitwidth % srcElemBitwidth != 0) + return rewriter.notifyMatchFailure(op, "unaligned cases are not supported"); + if (srcElemBitwidth != 2 && srcElemBitwidth != 4) + return rewriter.notifyMatchFailure( + op, "only src bitwidth of 2 or 4 is supported at this moment"); - const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; - if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0) + const int numSrcElemsPerByte = 8 / srcElemBitwidth; + if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0) return rewriter.notifyMatchFailure( - op, "Not an even number of i4 elements in trailing dim"); + op, "the trailing dimension of the input vector of sub-bytes must be a " + "multiple of 8 / "); return success(); } @@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep( return runningResult; } -/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and -/// bitwise ops that take advantage of high-level information to avoid leaving -/// LLVM to scramble with peephole optimizations. 
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, - Value srcValue) { - VectorType srcVecType = cast(srcValue.getType()); - assert(srcVecType.getElementType().isSignlessInteger(4) && - "Expected i4 type"); +/// Bitcasts the aligned `subByteVec` vector to a vector of i8. +/// Here, aligned means that it satisfies alignedConversionPrecondition. +/// +/// Example: +/// vector<16x16xi2> -> vector<16x4xi8> +/// vector<16x16xi4> -> vector<16x8xi8> +static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc, + Value subByteVec) { + auto srcVecType = cast(subByteVec.getType()); + int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth(); + assert(8 % srcBitwidth == 0 && + "Unsupported sub-byte type (not a divisor of i8)"); + int64_t numSrcElemsPerByte = 8 / srcBitwidth; + SmallVector vecShape(srcVecType.getShape()); + // Adjust last dimension of the vector, so the total size remains the same. + vecShape.back() = vecShape.back() / numSrcElemsPerByte; + auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type()); + return rewriter.create(loc, i8VecType, subByteVec); +} - // 1. Generate a bitcast vector -> vector. - SmallVector i8VecShape = llvm::to_vector(srcVecType.getShape()); - constexpr int64_t i4Toi8BitwidthFactor = 2; - i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor; - auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type()); - Value i8Vector = rewriter.create(loc, i8VecType, srcValue); +/// Extracts a signed N-bit sequence from each element of a vector of bytes, +/// starting at the specified bit index. +/// The `bitIdx` starts at 0 from the LSB and moves to the left. +/// +/// Example for a single element: +/// Extract numBits=2 starting at bitIdx=2 +/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0] +/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0] +/// target = [. . . . ^ ^ . .] +/// +/// The target sequence is [11](decimal=-1) as signed 2-bit integer. 
+/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer. +/// +/// src = [01 01 11 10] +/// shl = arith.shl(src, 4) -> [11 10 00 00] +/// result = arith.shrsi(shl, 6) -> [11 11 11 11] +static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter, + Location loc, Value src, + int bitIdx, int numBits) { + auto srcType = cast(src.getType()); + Value shl = src; + int8_t bitsToShiftLeft = 8 - numBits - bitIdx; + assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 && + "Invalid bitIdx range"); + if (bitsToShiftLeft != 0) { + Value shiftLeftValues = rewriter.create( + loc, DenseElementsAttr::get(srcType, bitsToShiftLeft)); + shl = rewriter.create(loc, src, shiftLeftValues); + } - // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each - // byte are place in one vector and the high i4 elements in another vector. - constexpr int8_t bitsToShift = 4; - auto shiftValues = rewriter.create( - loc, DenseElementsAttr::get(i8VecType, bitsToShift)); - Value shl = rewriter.create(loc, i8Vector, shiftValues); - Value low = rewriter.create(loc, shl, shiftValues); - Value high = rewriter.create(loc, i8Vector, shiftValues); + int8_t bitsToShiftRight = 8 - numBits; + Value shiftRightValues = rewriter.create( + loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); + Value shr = rewriter.create(loc, shl, shiftRightValues); + return shr; +} - // 3. Interleave low and high i8 elements. - return rewriter.create(loc, low, high); +/// Extracts an unsigned N-bit sequence from each element of a vector of bytes, +/// starting at the specified bit index. +/// The `bitIdx` starts at 0 from the LSB and moves to the left. +/// +/// Example for a single element: +/// Extract numBits=2 starting at bitIdx=2 +/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0] +/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0] +/// target = [. . . . ^ ^ . .] +/// +/// The target sequence is [10](decimal=2) as unsigned 2-bit integer. 
+/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer. +/// +/// src = [01 01 10 10] +/// mask = [00 00 00 11] +/// shr = arith.shrui(src, 2) = [00 01 01 10] +/// result = arith.andi(shr, mask) = [00 00 00 10] +/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be +/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking. +/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift +/// left when the index is 0. +static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter, + Location loc, Value src, + int bitIdx, int numBits) { + assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 && + "Invalid bitIdx range"); + auto srcType = cast(src.getType()); + int8_t bitsToShiftRight = bitIdx; + Value shr = src; + if (bitsToShiftRight != 0) { + Value shiftRightValues = rewriter.create( + loc, DenseElementsAttr::get(srcType, bitsToShiftRight)); + shr = rewriter.create(loc, src, shiftRightValues); + } + if (bitIdx + numBits == 8) { + return shr; + } + uint8_t lowBitsMask = (1 << numBits) - 1; + Value lowBitsMaskValues = rewriter.create( + loc, DenseElementsAttr::get(srcType, lowBitsMask)); + return rewriter.create(loc, shr, lowBitsMaskValues); } -/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and -/// bitwise ops that take advantage of high-level information to avoid leaving -/// LLVM to scramble with peephole optimizations. -static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc, - Value srcValue) { - VectorType srcVecType = cast(srcValue.getType()); +using ExtractNBitsFn = + std::function; + +/// Rewrite the i4 -> i8 extension into a sequence of shuffles and +/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations. 
+static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc, + Value srcValue, const ExtractNBitsFn &extFn) { + auto srcVecType = cast(srcValue.getType()); assert(srcVecType.getElementType().isSignlessInteger(4) && "Expected i4 type"); // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>. - SmallVector i8VecShape = llvm::to_vector(srcVecType.getShape()); - constexpr int64_t i4Toi8BitwidthFactor = 2; - i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor; - auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type()); - Value i8Vector = rewriter.create(loc, i8VecType, srcValue); - - // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each - // byte are placed in one vector and the high i4 elements in another vector. - constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask - auto lowBitsMaskValues = rewriter.create( - loc, DenseElementsAttr::get(i8VecType, lowBitsMask)); - Value low = rewriter.create(loc, i8VecType, i8Vector, - lowBitsMaskValues); - constexpr int8_t highBitsToShift = 4; - auto highShiftValues = rewriter.create( - loc, DenseElementsAttr::get(i8VecType, highBitsToShift)); - Value high = rewriter.create(loc, i8Vector, highShiftValues); + Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue); + + // 2. Extend i4 elements to i8 elements. Low i4 elements of each + // byte are placed in one vector and the high i4 elements in another vector. + Value low = extFn(rewriter, loc, i8Vector, 0, 4); + Value high = extFn(rewriter, loc, i8Vector, 4, 4); // 3. Interleave low and high i8 elements. return rewriter.create(loc, low, high); } +/// Rewrite the i2 -> i8 extension into a sequence of shuffles and +/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations. 
+static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc, + Value srcValue, const ExtractNBitsFn &extFn) { + VectorType srcVecType = cast(srcValue.getType()); + assert(srcVecType.getElementType().isSignlessInteger(2) && + "Expected i2 type"); + + // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>. + Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue); + + // 2. Extract each i2 element + // Position 0 (bits 0-1) + Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2); + // Position 1 (bits 2-3) + Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2); + // Position 2 (bits 4-5) + Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2); + // Position 3 (bits 6-7) + Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2); + + // 3. Interleave all 4 elements by first interleaving + // even elements and then odd + // vec0 = [0,0,0,0],... + // vec1 = [1,1,1,1],... + // vec2 = [2,2,2,2],... + // vec3 = [3,3,3,3],... + // 02 = [0,2,0,2,0,2,0,2],... + // 13 = [1,3,1,3,1,3,1,3],... + // 0213 = [0,1,2,3,...],... + Value interleave02 = rewriter.create(loc, vec0, vec2); + Value interleave13 = rewriter.create(loc, vec1, vec3); + return rewriter.create(loc, interleave02, interleave13); +} + /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise -/// ops that take advantage of high-level information to avoid leaving LLVM to -/// scramble with peephole optimizations. +/// ops to avoid leaving LLVM to scramble with peephole optimizations. static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue) { VectorType srcVecType = cast(srcValue.getType()); @@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern { return failure(); // Perform the rewrite. + Location loc = conversionOp.getLoc(); + const auto &extFn = isSigned ? 
extractNBitsPerByteAndSignExtendToI8 + : extractNBitsPerByteAndExtendToI8; Value subByteExt; - if (isSigned) { - subByteExt = - rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue); - } else { - subByteExt = - rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue); + switch (srcVecType.getElementType().getIntOrFloatBitWidth()) { + case 2: + subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn); + break; + case 4: + subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn); + break; + default: + return failure(); } // Finalize the rewrite. @@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp))) return failure(); + // TODO: Add support for truncating to i2. + if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2) + return failure(); + // Check general alignment preconditions. We invert the src/dst type order // to reuse the existing precondition logic. if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType, diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index 210025e30d7db..8d28f248e392d 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -193,6 +193,25 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> { return %1 : vector<8xi17> } + +// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4). +// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8( +func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8> + %0 = arith.extsi %a : vector<1xi4> to vector<1xi8> + return %0 : vector<1xi8> +} + +// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2). 
+// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8( +func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8> + %0 = arith.extsi %a : vector<2xi2> to vector<2xi8> + return %0 : vector<2xi8> +} + // CHECK-LABEL: func.func @aligned_extsi_i4_to_i8( func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { @@ -206,6 +225,31 @@ func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { return %0 : vector<8xi8> } +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8( +func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> + %0 = arith.extsi %a : 
vector<8xi2> to vector<8xi8> + return %0 : vector<8xi8> +} + // CHECK-LABEL: func.func @aligned_extsi_i4_to_i32( func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { @@ -220,8 +264,34 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { return %0 : vector<8xi32> } -// CHECK-LABEL: func.func @aligned_extsi_2d( -func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32( +func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> +// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extsi %a : vector<8xi2> to 
vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d( +func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { // CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> @@ -234,6 +304,32 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { return %0 : vector<8x32xi32> } +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d( +func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> +// CHECK: %[[RESULT:.*]] = arith.extsi 
%[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + // CHECK-LABEL: func.func @aligned_trunci_i8_to_i4( func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> { @@ -292,6 +388,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> { return %0 : vector<3x8x32xi4> } +func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2> + %0 = arith.trunci %a : vector<8xi8> to vector<8xi2> + return %0 : vector<8xi2> +} + // CHECK-LABEL: func.func @aligned_extui_i4_to_i8( func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { @@ -305,6 +408,31 @@ func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { return %0 : vector<8xi8> } +// CHECK-LABEL: func.func @aligned_extui_i2_to_i8( +func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 6-7 +// 
CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> + %0 = arith.extui %a : vector<8xi2> to vector<8xi8> + return %0 : vector<8xi8> +} + // CHECK-LABEL: func.func @aligned_extui_i4_to_i32( func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { @@ -319,8 +447,34 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { return %0 : vector<8xi32> } -// CHECK-LABEL: func.func @aligned_extui_2d( -func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-LABEL: func.func @aligned_extui_i2_to_i32( +func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: 
%[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> +// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extui %a : vector<8xi2> to vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d( +func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { // CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> // CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> @@ -333,6 +487,32 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { return %0 : vector<8x32xi32> } +// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d( +func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 6-7 
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> +// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + // CHECK-LABEL: func.func @aligned_sitofp( func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {