diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 87c30a733c363..181c394edc1d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -128,34 +128,17 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
             return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                          newMaskOperands);
           })
-      .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                        -> std::optional<Operation *> {
-        ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-        size_t numMaskOperands = maskDimSizes.size();
-        int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-        int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-        int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                             numSrcElemsPerDest);
-
-        // TODO: we only want the mask between [startIndex, maskIndex]
-        // to be true, the rest are false.
-        if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-          return std::nullopt;
-
-        SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-        newMaskDimSizes.push_back(maskIndex);
-
-        if (numFrontPadElems == 0)
-          return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                         newMaskDimSizes);
-
-        SmallVector<bool> newMaskValues;
-        for (int64_t i = 0; i < numDestElems; ++i)
-          newMaskValues.push_back(i >= startIndex && i < maskIndex);
-        auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-        return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                  newMask);
-      })
+      .Case<vector::ConstantMaskOp>(
+          [&](auto constantMaskOp) -> std::optional<Operation *> {
+            // Take the shape of mask, compress its trailing dimension:
+            SmallVector<int64_t> maskDimSizes(
+                constantMaskOp.getMaskDimSizes());
+            int64_t &maskIndex = maskDimSizes.back();
+            maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+                                         numSrcElemsPerDest);
+            return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                           maskDimSizes);
+          })
       .Case<arith::ConstantOp>([&](auto constantOp)
                                    -> std::optional<Operation *> {
         // TODO: Support multiple dimensions.
@@ -604,7 +587,6 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // See #115653
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 721c8a8d5d203..4332e80feed42 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -42,22 +42,20 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 
 // -----
 
-func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
-  %cst = arith.constant dense<0> : vector<3x5xi2>
   %mask = vector.constant_mask [3] : vector<5xi1>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru : memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-  return %2 : vector<3x5xi2>
+  return %1 : vector<5xi2>
 }
 
-
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
-// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +121,29 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec
 
 // -----
 
+// This test is similar to @vector_constant_mask_maskedload_i2, but the mask is multi-dimensional.
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: vector.extract %[[ORIG_MASK]][1]
+
+// Compressing the mask used for emulated masked load.
+// The innermost dimension is compressed to 2 elements from 5.
+// CHECK: %[[NEW_COMPRESSED_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: vector.extract %[[NEW_COMPRESSED_MASK]][1]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -252,7 +273,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +322,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 
 // -----
 
-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
   %c0 = arith.constant 0 : index
@@ -311,24 +332,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
   return %1 : vector<5xi2>
 }
 
-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
 // CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
+// Emulated masked load from alloc:
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
 // Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<true> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>