From 18cbf37b552a10c28564ef686e390e9c504face2 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Thu, 1 May 2025 14:48:35 +0100
Subject: [PATCH 1/2] [mlir][vector] Refactor `createWriteOrMaskedWrite`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch updates `createWriteOrMaskedWrite` to make it consistent
with `createReadOrMaskedRead`.

Before diving into the details, note that these two utilities are
currently implemented in different files: "VectorUtils.cpp" (Vector)
and "Vectorization.cpp" (Linalg). In a subsequent patch, I plan to move
`createWriteOrMaskedWrite` into "VectorUtils.cpp".

SUMMARY OF CHANGES:

The main change is to remove the logic that creates the destination
tensor, which previously looked like:

```cpp
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                             inputType.getElementType());
```

With this patch, `createWriteOrMaskedWrite` now simply generates:

```mlir
%res = vector.transfer_write %vectorToStore into %dest
```

This replaces the previous form:

```mlir
%dest = tensor.empty(%destSizes)
%res = vector.transfer_write %vectorToStore into %dest
```

In other words, the destination value `%dest` is now passed as an input
parameter. This makes `createWriteOrMaskedWrite` reusable in contexts
where the destination tensor is already known, for example in
`vectorizeAsInsertSliceOp`, which I will update in a follow-up patch.

OTHER CHANGES:

* Added comments and clarified TODOs.
* Updated tests: since destination sizes are now computed independently
  inside `createWriteOrMaskedWrite`, some additional `tensor.dim` ops
  appear. These will be cleaned up by CSE + canonicalization.
---
 .../Linalg/Transforms/Vectorization.cpp       | 98 +++++++++++--------
 mlir/test/Dialect/Linalg/vectorization.mlir   |  8 +-
 2 files changed, 61 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a477c2fb3f8cb..12ecdf9494bef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,72 +1506,68 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Creates a TransferWriteOp to write `input` into a newly initialized
-/// output tensor.
+/// Creates an optionally masked TransferWriteOp
 ///
-/// Given:
-/// - an input vector to write,
-/// - the mixed destination sizes for the output tensor,
-/// - and the vector sizes used for vectorization (i.e., the leading N dims,
-///   for some value of N),
-///
-/// this function generates the following sequence of ops:
-///
-///   %dest = tensor.empty(%destSizes)
-///   %res = vector.transfer_write %input into %dest
+/// Generates the following operation:
+///   %res = vector.transfer_write %vectorToStore into %dest
 ///
 /// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims` (where N =
-/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
-/// correctness:
+/// `inputVecSizesForLeadingDims`, where=
+///   * N = rank(`inputVecSizesForLeadingDims`)),
+/// masking is applied to ensure correctness:
 ///
-///   %dest = tensor.empty(%destSizes)
-///   %write = vector.transfer_write %input into %dest
-///   %mask = vector.create_mask(%destSizes)
+///   %write = vector.transfer_write %vectorToStore into %dest
+///   %mask = vector.create_mask(%destShape)
 ///   %res = vector.mask %mask { %write }
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %dest = tensor.empty(%destSizes)
+///   %write = vector.transfer_write %vectorToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: all write offsets are set to 0.
+/// NOTE: All write offsets are set to 0.
+/// TODO: Allow specifying write offsets.
 /// NOTE: When N < rank(input), the missing vector sizes are effectively
 /// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static. Supporting dynamic sizes will require the user to specify
-/// the remaining vector sizes. This is left as a TODO.
+/// must be static.
+/// TODO: Support cases where an arbitrary dim is dynamic - this will require
+/// specifying all the vector sizes.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
-                         SmallVector<OpFoldResult> destSizes,
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
+                         Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
                          bool useInBoundsInsteadOfMasking = false) {
-  auto inputType = cast<VectorType>(input.getType());
-  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
+             static_cast<int64_t>(destType.getRank()) &&
          "Rank mismatch!");
-  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
-                                               inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
+
+  // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
-    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+    assert(inputVecSizesForLeadingDims.size() ==
+               static_cast<size_t>(destType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
+
+  // Generate the xfer_write Op
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
-      /*vector=*/input,
+      /*vector=*/vectorToStore,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
@@ -1579,11 +1575,17 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
              destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
+  // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
+
+  // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
                    destShape.take_front(inputVecSizesForLeadingDims.size()));
+
+  // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
                           inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite =
-        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    Value maskForWrite = builder.create<vector::CreateMaskOp>(
+        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
+
   return write;
 }
 
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0],
+      transposeOp.getResult().getType().getElementType());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               /*destSizes=*/reifiedReturnShapes[0],
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
                                /*inputVecSizesForLeadingDims=*/inputVectorSizes,
                                /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       unpackOp.getDestType().hasStaticShape()
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-      useInBoundsInsteadOfMasking);
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0],
+      shapeCastOp.getResult().getType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
+                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+                               useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0],
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-      /*useInBoundsInsteadOfMasking=*/false);
+
+  // Create Xfer write Op
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 299be1296aa66..6b760a15afd56 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -641,7 +641,9 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-SAME:     } : vector<2x4xi1> -> vector<2x4xf32>
   // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
-  // CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+  // CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+  // CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+  // CHECK: %[[mask_2:.*]] = vector.create_mask %[[d2]], %[[d3]] : vector<2x4xi1>
   // CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
   // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
@@ -800,7 +802,9 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>)
-// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+// CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d2]], %[[d3]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
 // CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
 // CHECK-SAME: vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
 // CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>

From 2cd73e6a3f338467820eeec5be277d97a32d85f3 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Wed, 14 May 2025 09:56:30 +0100
Subject: [PATCH 2/2] fixup!
 [mlir][vector] Refactor `createWriteOrMaskedWrite`

Fix comments
---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 12ecdf9494bef..1dea0c8292e67 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1512,13 +1512,13 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
 /// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims`, where=
-///   * N = rank(`inputVecSizesForLeadingDims`)),
+/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
 ///   %mask = vector.create_mask(%destShape)
-///   %res = vector.mask %mask { %write }
+///   %res = vector.mask %mask {
+///     vector.transfer_write %vectorToStore into %dest
+///   }
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
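
For reference, this is the net effect of the series on the dynamic pad case
exercised by `@test_masked_vectorize_dynamic_pad` above: the caller now
materializes the destination via `tensor.empty`, and `createWriteOrMaskedWrite`
re-derives the mask sizes from that destination via `tensor::getMixedSizes`,
which materializes as the extra `tensor.dim` ops. A minimal sketch of the
generated IR, with illustrative SSA names (`%vec`, `%d0` and `%d1` are assumed
inputs, not names from the actual test output):

```mlir
// Destination is now created by the caller, not inside
// createWriteOrMaskedWrite.
%dest = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Mask sizes are recomputed from %dest; CSE + canonicalization later fold
// these tensor.dim ops away.
%dim0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %dim0, %dim1 : vector<2x4xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<2x4xf32>, tensor<?x?xf32>
} : vector<2x4xi1> -> tensor<?x?xf32>
```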