From 309abfe7f6558a61126827e9a0bd81e29273af10 Mon Sep 17 00:00:00 2001
From: Alan Li
Date: Tue, 12 Nov 2024 22:43:58 -0500
Subject: [PATCH] [MLIR] Fix VectorEmulateNarrowType constant op mask bug

This commit adds support for handling mask constants generated by the
`arith.constant` op in the `VectorEmulateNarrowType` pattern.
Previously, this pattern would not match due to the lack of mask
constant handling in `getCompressedMaskOp`.

The changes include:

1. Updating `getCompressedMaskOp` to recognize and handle
   `arith.constant` ops as mask value sources.
2. Handling cases where the mask is not aligned with the emulated load
   width. The compressed mask is adjusted to account for the offset.

Limitations:
- The `arith.constant` op can only have 1-dimensional constant values.

Resolves: #115742

Signed-off-by: Alan Li
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 169 +++++++++++-------
 .../vector-emulate-narrow-type-unaligned.mlir |  38 ++++
 .../Vector/vector-emulate-narrow-type.mlir    |  51 ++++++
 3 files changed, 198 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994ae..dc8bab325184b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,83 +75,134 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                    int numSrcElemsPerDest,
                                                    int numFrontPadElems = 0) {
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
-  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
-                     numSrcElemsPerDest;
+  auto numDestElems =
+      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+      numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
+  // TODO: add support for `vector.splat`.
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
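+  // Illustrative example (values chosen here, not taken from the original
+  // patch): with numSrcElemsPerDest = 4 (e.g. i2 data emulated on i8), a
+  // vector<3x8xi1> mask compresses to vector<3x2xi1>; only the trailing
+  // dimension changes, since ceil(8 / 4) = 2.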
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numDestElems;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask between [startIndex, maskIndex]
+            // to be true and the rest to be false.
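+            // Worked example (illustrative values, not taken from the
+            // original patch): numFrontPadElems = 1, numSrcElemsPerDest = 4
+            // and origIndex = 3 give startIndex = 0 and
+            // maskIndex = ceil((1 + 3) / 4) = 1.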
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numDestElems; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      newMask);
+          })
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            //   %mask = [0, 1, 0, 1, 0, 0]
+            //
+            // With a front offset of 1, the mask will be padded with 0s in
+            // the front and back so that:
+            // 1. It is aligned with the effective loading bits
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            //    total coverage size is a multiple of bytes). The new mask
+            //    will look like this before compressing:
+            //
+            //   %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
+            auto originalMask =
+                cast<DenseIntElementsAttr>(constantOp.getValue());
+            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
+            paddedMaskValues.append(originalMask.template value_begin<bool>(),
+                                    originalMask.template value_end<bool>());
+            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < paddedMaskValues.size();
+                 i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= paddedMaskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
+
+  if (!newMask)
+    return failure();
 
   while (!extractOps.empty()) {
     newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
     extractOps.pop_back();
   }
 
-  return newMask;
+  return *newMask;
 }
 
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +236,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579..b1a0d4f924f3c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,41 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
+// CHECK-SAME:  %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
+// CHECK-SAME:  {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
+// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from the emulated loaded vector and the passthru vector:
+// TODO: fold this part if possible.
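+// The original 5-element mask is inserted into an all-false 8-element mask at
+// offset 1, so the select only keeps the 5 loaded i2 values; the padding lanes
+// fall back to the (padded) passthru.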
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME:  {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_PADDED]], %[[MASKLOAD_DOWNCAST]], %[[PTH_PADDED]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME:  {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[RESULT]] : vector<5xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e..7a3ba95893383 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -275,6 +275,30 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt
 
 // -----
 
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<3x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME:  %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+
+// Emit a new, compressed mask for the emulated maskedload:
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[PTHU_UPCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[COMPRESSED_MASK]], %[[PTHU_UPCAST]]
+// CHECK: %[[LOAD_DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[LOAD_DOWNCAST]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>
+
 ///----------------------------------------------------------------------------------------
 /// vector.extract -> vector.masked_load
 ///----------------------------------------------------------------------------------------
@@ -624,3 +648,30 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
 // CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i4_arith_constant(%val_to_store: vector<8xi4>) {
+  %0 = memref.alloc() : memref<5x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  vector.maskedstore %0[%c3, %c0], %mask, %val_to_store :
+    memref<5x8xi4>, vector<8xi1>, vector<8xi4>
+  return
+}
+
+// CHECK-LABEL: func @vector_maskedstore_i4_arith_constant
+// CHECK-SAME:    %[[VAL_TO_STORE:[a-zA-Z0-9]+]]:
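+// The store is emulated as a read-modify-write on the compressed i8 view:
+// masked load of the affected bytes, select with the original mask, then a
+// masked store of the result.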
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<20xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
+// Each row holds 8 i4 values = 4 bytes, so index %c3 flattens to 3 * 4 = 12.
+// CHECK: %[[IDX_FLATTENED:.+]] = arith.constant 12 : index
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[IDX_FLATTENED]]], %[[COMPRESSED_MASK]], %[[EMPTY]]
+// CHECK: %[[LOAD_UPCAST:.+]] = vector.bitcast %[[MASKEDLOAD]]
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[VAL_TO_STORE]], %[[LOAD_UPCAST]]
+// CHECK: %[[SELECT_DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.maskedstore %[[ALLOC]][%[[IDX_FLATTENED]]], %[[COMPRESSED_MASK]], %[[SELECT_DOWNCAST]]
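+// For reference, with 2 i4 elements per byte the source mask
+// [false, true, true, true, true, true, false, false] compresses pair-wise to
+// [false|true, true|true, true|true, false|false] = [true, true, true, false],
+// which is the COMPRESSED_MASK checked above.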