diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc8bab325184b..87c30a733c363 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -110,13 +110,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
           .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
-               size_t numMaskOperands = maskOperands.size();
+               // The `vector.create_mask` op creates a mask without any
+               // zeros at the front. Also, because `numFrontPadElems` is
+               // strictly smaller than `numSrcElemsPerDest`, the compressed
+               // mask generated by padding the original mask by
+               // `numFrontPadElems` will not have any zeros at the front
+               // either.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
-               s0 = s0 + numSrcElemsPerDest - 1;
-               s0 = s0.floorDiv(numSrcElemsPerDest);
-               OpFoldResult origIndex =
-                   getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+               s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
+               OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
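Aside for reviewers, not part of the patch: the new bound can be sanity-checked offline. For any mask bound m >= 1, the masked source elements occupy positions [numFrontPadElems, numFrontPadElems + m), so the last destination element touched is floordiv(numFrontPadElems + m - 1, numSrcElemsPerDest), and ceildiv(m + numFrontPadElems, numSrcElemsPerDest) is exactly the number of destination elements to enable. A minimal standalone sketch; `ceilDiv` and `compressedMaskBound` are hypothetical helper names, not MLIR APIs:

```cpp
#include <cassert>
#include <cstdint>

// ceildiv(a, b) for a >= 0, b > 0, mirroring AffineExpr::ceilDiv semantics.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// The bound computed by `(s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest)`.
static int64_t compressedMaskBound(int64_t m, int64_t numFrontPadElems,
                                   int64_t numSrcElemsPerDest) {
  return ceilDiv(m + numFrontPadElems, numSrcElemsPerDest);
}

int main() {
  for (int64_t n : {2, 4, 8})             // numSrcElemsPerDest
    for (int64_t pad = 0; pad < n; ++pad) // numFrontPadElems < n
      for (int64_t m = 1; m <= 4 * n; ++m) {
        // Brute force: destination elements touched by the padded source
        // elements [pad, pad + m).
        int64_t first = pad / n;          // always 0 because pad < n
        int64_t last = (pad + m - 1) / n; // inclusive
        assert(first == 0);               // hence no zeros at the front
        assert(compressedMaskBound(m, pad, n) == last + 1);
      }
  return 0;
}
```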
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b1a0d4f924f3c..721c8a8d5d203 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -42,7 +42,7 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {

 // -----

-func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %cst = arith.constant dense<0> : vector<3x5xi2>
   %mask = vector.constant_mask [3] : vector<5xi1>
@@ -54,7 +54,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
   return %2 : vector<3x5xi2>
 }

-// CHECK-LABEL: func @vector_cst_maskedload_i2(
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
 // CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
 // CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
@@ -74,6 +74,55 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {

 // -----

+// Tests correct generation of a compressed mask for `vector.create_mask` with a dynamic mask bound.
+// Specifically, the program masked-loads a vector<5xi2> starting at `memref<3x5xi2>[1, 0]`, with the
+// number of enabled mask elements given by an unknown value `m`. After the emulation transformation,
+// it masked-loads 2 bytes at linearized index `memref<4xi8>[1]`, with a compressed mask given by `ceildiv(m + 1, 4)`.
+func.func @unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %m : vector<5xi1>
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) ceildiv 4)>
+// CHECK: func @unaligned_create_mask_dynamic_i2(
+// CHECK-SAME: %[[NUM_ELEMS_TO_LOAD:.+]]: index, %[[PASSTHRU:.+]]: vector<5xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[COMPRESSED_MASK:.+]] = affine.apply #[[MAP]]()[%[[NUM_ELEMS_TO_LOAD]]]
+// CHECK: vector.create_mask %[[COMPRESSED_MASK]] : vector<2xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]]
+
+// -----
+
+// Tests correct generation of a compressed mask for `vector.create_mask`. Same as the previous
+// test, but with a static mask bound.
+// In this case, the loaded slice `vector<7xi2>` spans 3 bytes.
+func.func @unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vector<7xi2> {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3 : vector<7xi1>
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2>
+  return %1 : vector<7xi2>
+}
+
+// CHECK: func @unaligned_create_mask_static_i2(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[COMP_MASK:.+]] = vector.create_mask %[[C2]] : vector<3xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMP_MASK]]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
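Aside, not part of the patch: the constants in the two tests above can be reproduced with the following standalone sketch of the emulation arithmetic; `emulateIndices` and its fields are hypothetical names for illustration, not pass APIs.

```cpp
#include <cassert>
#include <cstdint>

struct Emulated {
  int64_t byteIndex;        // linearized index into the emulated i8 memref
  int64_t numFrontPadElems; // narrow elements preceding the slice in its byte
  int64_t numBytes;         // bytes spanned by the loaded slice
};

// Index arithmetic for loading `numElems` narrow elements starting at
// [row, col] of a rank-2 memref with `rowSize` columns and `bitWidth`-bit
// element type.
static Emulated emulateIndices(int64_t row, int64_t col, int64_t rowSize,
                               int64_t bitWidth, int64_t numElems) {
  int64_t elemsPerByte = 8 / bitWidth; // 4 for i2
  int64_t linearElem = row * rowSize + col;
  int64_t byteIndex = linearElem / elemsPerByte;
  int64_t pad = linearElem % elemsPerByte;
  int64_t numBytes = (pad + numElems + elemsPerByte - 1) / elemsPerByte;
  return {byteIndex, pad, numBytes};
}

int main() {
  // @unaligned_create_mask_dynamic_i2: vector<5xi2> at memref<3x5xi2>[1, 0]
  // -> 2 bytes starting at byte 1, mask map ceildiv(m + 1, 4).
  Emulated a = emulateIndices(1, 0, 5, 2, 5);
  assert(a.byteIndex == 1 && a.numFrontPadElems == 1 && a.numBytes == 2);

  // @unaligned_create_mask_static_i2: vector<7xi2> at memref<3x7xi2>[1, 0]
  // -> 3 bytes starting at byte 1, compressed mask ceildiv(3 + 3, 4) == 2.
  Emulated b = emulateIndices(1, 0, 7, 2, 7);
  assert(b.byteIndex == 1 && b.numFrontPadElems == 3 && b.numBytes == 3);
  return 0;
}
```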
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 7a3ba95893383..9dc3eb6989c6c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -141,7 +141,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK-NEXT:   return

 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 // CHECK32: func @vector_maskedload_i8(
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
@@ -169,7 +169,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
   return %2 : vector<3x8xi4>
 }
 // CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 // CHECK: func @vector_maskedload_i4(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -185,7 +185,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>

 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 // CHECK32: func @vector_maskedload_i4(
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -497,7 +497,7 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 // CHECK-NEXT:  return

 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 // CHECK32: func @vector_maskedstore_i8(
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -530,7 +530,7 @@ func.func @vector_maskedstore_i4(
   return
 }

 // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 // CHECK-LABEL: func.func @vector_maskedstore_i4(
 // CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
@@ -550,7 +550,7 @@ func.func @vector_maskedstore_i4(
 // CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>

 // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 // CHECK32-LABEL: func.func @vector_maskedstore_i4(
 // CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
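Aside, not part of the patch: the CHECK updates in this file are purely cosmetic. For these aligned tests `numFrontPadElems` is 0, and for s0 >= 0 and k > 0 the expressions `(s0 + k - 1) floordiv k` and `s0 ceildiv k` denote the same value; the pass now builds the expression with `ceilDiv`, so only the printed spelling of the affine maps changes. A throwaway check of the identity:

```cpp
#include <cassert>
#include <cstdint>

// floordiv for non-negative operands.
static int64_t floorDiv(int64_t a, int64_t b) { return a / b; }
// ceildiv for non-negative a, positive b.
static int64_t ceilDiv(int64_t a, int64_t b) {
  return a == 0 ? 0 : (a - 1) / b + 1;
}

int main() {
  // k ranges over the numSrcElemsPerDest values in the tests above.
  for (int64_t k : {2, 4, 8})
    for (int64_t s0 = 0; s0 <= 64; ++s0)
      assert(ceilDiv(s0, k) == floorDiv(s0 + k - 1, k));
  return 0;
}
```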