From d661413b787a15a91f772e2570333aad95166a68 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Thu, 4 Sep 2025 17:19:24 -0700 Subject: [PATCH 01/21] [mlir][vector] Add support for lowering n-D vector.to_elements op. The revision adds a pattern that flattens 2 or more dimensional `vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`. It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests. It recovers the e2e lowering breakage from https://github.com/llvm/llvm-project/commit/b4c31dc98dfc929728904cd96f0f4cf812c4d5b5 on LLVM path. Signed-off-by: hanhanW --- .../Vector/Transforms/LoweringPatterns.h | 6 +++ .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 + .../Dialect/Vector/Transforms/CMakeLists.txt | 1 + .../Transforms/LowerVectorToElements.cpp | 52 +++++++++++++++++++ .../VectorToLLVM/vector-to-llvm.mlir | 40 ++++++++++++++ .../Vector/vector-to-elements-lowering.mlir | 22 ++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++++ 7 files changed, 146 insertions(+) create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 47f96112a9433..e0f744841db2b 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns( void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [FlattenToElements] +void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 9852df6970fdc..0b44ca7ceee42 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); populateVectorFromElementsLoweringPatterns(patterns); + populateVectorToElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index acbf2b746037b..d74007f13a95b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp + LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp new file mode 100644 index 0000000000000..014034b8f9737 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -0,0 +1,52 @@ +//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.to_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-to-elements" + +using namespace mlir; + +namespace { + +/// Flattens 2 or more dimensional `vector.to_elements` ops by +/// `vector.shape_cast` + `vector.to_elements`. +struct FlattenToElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + VectorType vecType = op.getSource().getType(); + if (vecType.getRank() <= 1) + return rewriter.notifyMatchFailure( + op, "the rank is already less than or equal to 1"); + if (vecType.getNumScalableDims() > 0) + return rewriter.notifyMatchFailure( + op, "scalable vector is not yet supported"); + auto vec1DType = + VectorType::get({vecType.getNumElements()}, vecType.getElementType()); + Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), + vec1DType, op.getSource()); + rewriter.replaceOpWithNewOp(op, op.getResultTypes(), + shapeCast); + return success(); + } +}; + +} // namespace + +void mlir::vector::populateVectorToElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 07d335117de01..bf4b05f7874de 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> return %0 : vector<2x1x2xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +// vector.to_elements +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[V0]], %[[V1]] +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// NOTE: We flatten multi-dimensional to_elements ops with pattern +// `FlattenToElements` and then convert the 1-D to_elements ops to llvm. + +// CHECK-LABEL: func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32> +// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32> +// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64 +// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32> +// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir new file mode 100644 index 0000000000000..a57521c4db467 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] +// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index bb1598ee3efe5..560a1331bdaf0 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements } }; +struct TestFlattenVectorToElements + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements) + + StringRef getArgument() const final { + return "test-flatten-vector-to-elements"; + } + StringRef getDescription() const final { + return "Test flattening patterns for to_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); From edf4019862d3448a31bd2d4052fc9d11259a7e37 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 09:48:27 -0700 Subject: [PATCH 02/21] Add new populate patterns for flattening. --- .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 6 ++++++ .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 5 +++++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index e0f744841db2b..c39c9d4ae00c9 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -317,6 +317,12 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [FlattenToElements] +void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 014034b8f9737..33c5d2cb33369 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -50,3 +50,8 @@ void mlir::vector::populateVectorToElementsLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); } + +void mlir::vector::populateVectorToElementsFlatteningPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 560a1331bdaf0..01a00509c7331 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -825,7 +825,7 @@ struct TestFlattenVectorToElements void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateVectorToElementsLoweringPatterns(patterns); + populateVectorToElementsFlatteningPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; From 12fd6fc77b44f25020aee2b0193a02117d2fc1e1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 08:49:00 -0700 Subject: [PATCH 03/21] [mlir][vector] Add function to unroll vectors. Extract n vector from vector This patch adds a utility function that will unroll vector values. This is different from the current utility function that focuses on unrolling vector operations. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 3 +++ mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index ace26990601c8..95f2ee5a7ac1d 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -255,6 +255,9 @@ using UnrollVectorOpFn = LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn); +LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter, + SmallVector &values); + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 841e1384e03b3..cbedd9563fc29 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef shape, return success(); } +LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter, + SmallVector &subvectors) { + assert(isa(vector.getType()) && "expected vector type"); + VectorType ty = cast(vector.getType()); + Location loc = vector.getLoc(); + if (ty.getRank() < 2) + return rewriter.notifyMatchFailure(loc, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (ty.getScalableDims().front()) + return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim"); + + // We just need zero indices for the all dimensions except the leading one. + for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) { + subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i)); + } + + return success(); +} + LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, vector::UnrollVectorOpFn unrollFn) { assert(op->getNumResults() == 1 && "expected single result"); From 0415e3b066c258e2b12aba4e074f62c07b1ee3d1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 09:39:24 -0700 Subject: [PATCH 04/21] [mlir][vector] Add vector.to_elements unrolling. --- .../Vector/Transforms/LoweringPatterns.h | 2 +- .../Transforms/LowerVectorToElements.cpp | 28 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index c39c9d4ae00c9..31150a2afc19f 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -313,7 +313,7 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, /// Populate the pattern set with the following patterns: /// -/// [FlattenToElements] +/// [UnrollToElements] void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 33c5d2cb33369..b86e8b274770f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -20,6 +20,32 @@ using namespace mlir; namespace { +struct UnrollToElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + SmallVector vectors; + LogicalResult match = + mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors); + if (failed(match)) { + return match; + } + + // May be large vector. + std::vector results; + for (const auto &vector : vectors) { + // we need to replace the current result + auto subElements = + rewriter.create(op.getLoc(), vector); + results.insert(results.end(), subElements.getResults().begin(), + subElements.getResults().end()); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + /// Flattens 2 or more dimensional `vector.to_elements` ops by /// `vector.shape_cast` + `vector.to_elements`. struct FlattenToElements : OpRewritePattern { @@ -48,7 +74,7 @@ struct FlattenToElements : OpRewritePattern { void mlir::vector::populateVectorToElementsLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), benefit); } void mlir::vector::populateVectorToElementsFlatteningPatterns( From 1b21117c35c588486aa737bdc7f5a37f5c52c1fb Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 10:02:39 -0700 Subject: [PATCH 05/21] Split tests for unrolling and flattening to elements --- .../Vector/vector-to-elements-flattening.mlir | 22 +++++++++++++++++ .../Vector/vector-to-elements-lowering.mlir | 10 ++++---- .../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir new file mode 100644 index 0000000000000..a57521c4db467 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] +// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir index a57521c4db467..e302dbd174322 100644 --- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s // CHECK-LABEL: func.func @to_elements_1d( // CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> @@ -13,9 +13,11 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { // CHECK-LABEL: func.func @to_elements_2d( // CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] -// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32> +// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32> +// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1 func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { %0:4 = vector.to_elements %arg0 : vector<2x2xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 01a00509c7331..093134c119cea 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -830,6 +830,28 @@ struct TestFlattenVectorToElements } }; +struct TestUnrollVectorToElements + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements) + + StringRef getArgument() const final { + return "test-unroll-vector-to-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for to_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -1105,6 +1127,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); From 567af4bde157e7c266dcf0df5be645d370a90086 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 10:09:47 -0700 Subject: [PATCH 06/21] Fix test --- .../VectorToLLVM/vector-to-llvm.mlir | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index bf4b05f7874de..2d33888854ea7 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1795,21 +1795,23 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { // ----- -// NOTE: We flatten multi-dimensional to_elements ops with pattern -// `FlattenToElements` and then convert the 1-D to_elements ops to llvm. +// NOTE: We unroll multi-dimensional to_elements ops with pattern +// `UnrollToElements` and then convert the 1-D to_elements ops to llvm. // CHECK-LABEL: func @to_elements_2d( // CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>> // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32> +// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32> // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32> -// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64 -// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32> -// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64 -// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32> -// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] +// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { %0:4 = vector.to_elements %arg0 : vector<2x2xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 From c58b5c41bcb66ea43547f87ad811bb57fa0b152e Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 12:25:12 -0700 Subject: [PATCH 07/21] Address review comments --- .../Transforms/LowerVectorToElements.cpp | 25 +++++++++---------- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2 -- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index b86e8b274770f..b897b15d7d690 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -20,26 +20,25 @@ using namespace mlir; namespace { -struct UnrollToElements : OpRewritePattern { +struct UnrollToElements final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ToElementsOp op, PatternRewriter &rewriter) const override { SmallVector vectors; - LogicalResult match = - mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors); - if (failed(match)) { + if (LogicalResult match = + vector::unrollVectorValue(op.getSource(), rewriter, vectors); + failed(match)) { return match; } - // May be large vector. - std::vector results; - for (const auto &vector : vectors) { + // May be a large vector. + SmallVector results; + for (const Value &vector : vectors) { // we need to replace the current result auto subElements = rewriter.create(op.getLoc(), vector); - results.insert(results.end(), subElements.getResults().begin(), - subElements.getResults().end()); + llvm::append_range(results, subElements.getResults()); } rewriter.replaceOp(op, results); return success(); @@ -48,7 +47,7 @@ struct UnrollToElements : OpRewritePattern { /// Flattens 2 or more dimensional `vector.to_elements` ops by /// `vector.shape_cast` + `vector.to_elements`. -struct FlattenToElements : OpRewritePattern { +struct FlattenToElements final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ToElementsOp op, @@ -57,9 +56,9 @@ struct FlattenToElements : OpRewritePattern { if (vecType.getRank() <= 1) return rewriter.notifyMatchFailure( op, "the rank is already less than or equal to 1"); - if (vecType.getNumScalableDims() > 0) - return rewriter.notifyMatchFailure( - op, "scalable vector is not yet supported"); + + assert(vecType.getNumScalableDims() == 0 && + "scalable vector is not yet supported"); auto vec1DType = VectorType::get({vecType.getNumElements()}, vecType.getElementType()); Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index cbedd9563fc29..8d67cd7e80382 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -395,7 +395,6 @@ vector::isValidMaskedInputVector(ArrayRef shape, LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter, SmallVector &subvectors) { - assert(isa(vector.getType()) && "expected vector type"); VectorType ty = cast(vector.getType()); Location loc = vector.getLoc(); if (ty.getRank() < 2) @@ -406,7 +405,6 @@ LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter, if (ty.getScalableDims().front()) return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim"); - // We just need zero indices for the all dimensions except the leading one. for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) { subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i)); } From 871f5a5873a765a89983a8d967d458faed6b3f2e Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 12:26:52 -0700 Subject: [PATCH 08/21] Remove unnecessary comment --- mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index b897b15d7d690..43ba1fe885e85 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -35,7 +35,6 @@ struct UnrollToElements final : OpRewritePattern { // May be a large vector. SmallVector results; for (const Value &vector : vectors) { - // we need to replace the current result auto subElements = rewriter.create(op.getLoc(), vector); llvm::append_range(results, subElements.getResults()); From 244eed587f232ae8c67af133887b6e6f3e8e351d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 5 Sep 2025 13:00:39 -0700 Subject: [PATCH 09/21] Address review comments --- .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 43ba1fe885e85..e1bb07287ae9b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -26,10 +26,8 @@ struct UnrollToElements final : OpRewritePattern { LogicalResult matchAndRewrite(vector::ToElementsOp op, PatternRewriter &rewriter) const override { SmallVector vectors; - if (LogicalResult match = - vector::unrollVectorValue(op.getSource(), rewriter, vectors); - failed(match)) { - return match; + if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) { + return failure(); } // May be a large vector. From 9933387796fba4c5155b718df9f1f0d5e20d47e4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 12:40:49 -0700 Subject: [PATCH 10/21] Add comment --- mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 95f2ee5a7ac1d..985f90f6ed955 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -255,6 +255,13 @@ using UnrollVectorOpFn = LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn); +/// Generic utility for mapping values of type vector +/// to n values of type vector +/// Follows the following pattern: +/// 1. Check if already 1-D. If so, return failure. +/// 2. Check for scalable dimensions. If so, return failure. +/// 3. Returns the values of n vector.extract operations corresponding +/// to the outermost dimension. LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter, SmallVector &values); From 351086b8983a77343a4716e5b1f237d07ec3e488 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 12:18:09 -0700 Subject: [PATCH 11/21] Disable and move vector flattening --- .../Vector/Transforms/LoweringPatterns.h | 6 ---- .../Transforms/LowerVectorToElements.cpp | 29 ------------------- .../Vector/Transforms/VectorLinearize.cpp | 24 +++++++++++++++ .../Vector/vector-to-elements-flattening.mlir | 22 -------------- .../Dialect/Vector/TestVectorTransforms.cpp | 24 --------------- 5 files changed, 24 insertions(+), 81 deletions(-) delete mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 31150a2afc19f..f56124cb4fb95 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -317,12 +317,6 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [FlattenToElements] -void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index e1bb07287ae9b..718b947f13715 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -42,38 +42,9 @@ struct UnrollToElements final : OpRewritePattern { } }; -/// Flattens 2 or more dimensional `vector.to_elements` ops by -/// `vector.shape_cast` + `vector.to_elements`. -struct FlattenToElements final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ToElementsOp op, - PatternRewriter &rewriter) const override { - VectorType vecType = op.getSource().getType(); - if (vecType.getRank() <= 1) - return rewriter.notifyMatchFailure( - op, "the rank is already less than or equal to 1"); - - assert(vecType.getNumScalableDims() == 0 && - "scalable vector is not yet supported"); - auto vec1DType = - VectorType::get({vecType.getNumElements()}, vecType.getElementType()); - Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), - vec1DType, op.getSource()); - rewriter.replaceOpWithNewOp(op, op.getResultTypes(), - shapeCast); - return success(); - } -}; - } // namespace void mlir::vector::populateVectorToElementsLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); } - -void mlir::vector::populateVectorToElementsFlatteningPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7dde6311fa809..0f5e9259d4c19 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -798,6 +798,30 @@ struct LinearizeVectorFromElements final } }; +/// Flattens 2 or more dimensional `vector.to_elements` ops by +/// `vector.shape_cast` + `vector.to_elements`. +struct FlattenToElements final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + VectorType vecType = op.getSource().getType(); + if (vecType.getRank() <= 1) + return rewriter.notifyMatchFailure( + op, "the rank is already less than or equal to 1"); + + assert(vecType.getNumScalableDims() == 0 && + "scalable vector is not yet supported"); + auto vec1DType = + VectorType::get({vecType.getNumElements()}, vecType.getElementType()); + Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), + vec1DType, op.getSource()); + rewriter.replaceOpWithNewOp(op, op.getResultTypes(), + shapeCast); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir deleted file mode 100644 index a57521c4db467..0000000000000 --- a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s - -// CHECK-LABEL: func.func @to_elements_1d( -// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> -// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1 -func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { - %0:2 = vector.to_elements %arg0 : vector<2xf32> - return %0#0, %0#1 : f32, f32 -} - -// ----- - -// CHECK-LABEL: func.func @to_elements_2d( -// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] -// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 -func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { - %0:4 = vector.to_elements %arg0 : vector<2x2xf32> - return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 -} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 093134c119cea..d6596cd341df7 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -808,28 +808,6 @@ struct TestUnrollVectorFromElements } }; -struct TestFlattenVectorToElements - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements) - - StringRef getArgument() const final { - return "test-flatten-vector-to-elements"; - } - StringRef getDescription() const final { - return "Test flattening patterns for to_elements ops"; - } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorToElementsFlatteningPatterns(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestUnrollVectorToElements : public PassWrapper> { @@ -1129,8 +1107,6 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration(); From 9f7d15d1c0ba453db4caa9b86281c3dc5f622bfc Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 12:30:50 -0700 Subject: [PATCH 12/21] Re-enable and rename flattening to linearize --- .../Vector/Transforms/VectorLinearize.cpp | 49 +++++++++++++------ mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++ 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 0f5e9259d4c19..54eb182a9680f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -798,26 +798,45 @@ struct LinearizeVectorFromElements final } }; -/// Flattens 2 or more dimensional `vector.to_elements` ops by -/// `vector.shape_cast` + `vector.to_elements`. -struct FlattenToElements final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ToElementsOp op, - PatternRewriter &rewriter) const override { - VectorType vecType = op.getSource().getType(); +/// This pattern linearizes the operand in `vector.to_elements` operations +/// by converting the result type to a 1-D vector while preserving all element +/// values. The transformation creates a linearized `vector.shape_cast` +/// followed by a `vector.to_elements`. +/// +/// Example: +/// +/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// +/// is converted to: +/// +/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32> +/// %0:4 = vector.to_elements %vector_cast : vector<4xf32> +/// +struct LinearizeVectorToElements final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorToElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + VectorType vecType = toElementsOp.getSource().getType(); if (vecType.getRank() <= 1) return rewriter.notifyMatchFailure( - op, "the rank is already less than or equal to 1"); + toElementsOp, "the rank is already less than or equal to 1"); assert(vecType.getNumScalableDims() == 0 && "scalable vector is not yet supported"); auto vec1DType = VectorType::get({vecType.getNumElements()}, vecType.getElementType()); - Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), - vec1DType, op.getSource()); - rewriter.replaceOpWithNewOp(op, op.getResultTypes(), - shapeCast); + Value shapeCast = vector::ShapeCastOp::create( + rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); + rewriter.replaceOpWithNewOp( + toElementsOp, toElementsOp.getResultTypes(), shapeCast); return success(); } }; @@ -914,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add( - typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 5e8bfd0698b33..fe697c8b9c057 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> return %1 : vector<2x2xf32> } + +// ----- + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] +// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} From a8f9b9cb13b1677bd5be6d662f8865f739671208 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 12:33:21 -0700 Subject: [PATCH 13/21] Remove LinearizeToElements --- .../Vector/Transforms/VectorLinearize.cpp | 47 +------------------ mlir/test/Dialect/Vector/linearize.mlir | 23 --------- 2 files changed, 2 insertions(+), 68 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 54eb182a9680f..7dde6311fa809 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -798,49 +798,6 @@ struct LinearizeVectorFromElements final } }; -/// This pattern linearizes the operand in `vector.to_elements` operations -/// by converting the result type to a 1-D vector while preserving all element -/// values. The transformation creates a linearized `vector.shape_cast` -/// followed by a `vector.to_elements`. -/// -/// Example: -/// -/// %0:4 = vector.to_elements %v : vector<2x2xf32> -/// -/// is converted to: -/// -/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32> -/// %0:4 = vector.to_elements %vector_cast : vector<4xf32> -/// -struct LinearizeVectorToElements final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LinearizeVectorToElements(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - VectorType vecType = toElementsOp.getSource().getType(); - if (vecType.getRank() <= 1) - return rewriter.notifyMatchFailure( - toElementsOp, "the rank is already less than or equal to 1"); - - assert(vecType.getNumScalableDims() == 0 && - "scalable vector is not yet supported"); - auto vec1DType = - VectorType::get({vecType.getNumElements()}, vecType.getElementType()); - Value shapeCast = vector::ShapeCastOp::create( - rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); - rewriter.replaceOpWithNewOp( - toElementsOp, toElementsOp.getResultTypes(), shapeCast); - return success(); - } -}; - } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -933,8 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add(typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index fe697c8b9c057..5e8bfd0698b33 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -538,26 +538,3 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> return %1 : vector<2x2xf32> } - -// ----- - -// CHECK-LABEL: func.func @to_elements_1d( -// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> -// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1 -func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { - %0:2 = vector.to_elements %arg0 : vector<2xf32> - return %0#0, %0#1 : f32, f32 -} - -// ----- - -// CHECK-LABEL: func.func @to_elements_2d( -// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] -// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> -// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 -func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { - %0:4 = vector.to_elements %arg0 : vector<2x2xf32> - return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 -} From 0568cbabd9c81e47cea30b0d01864ea0466f4633 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 10 Sep 2025 07:04:00 -0700 Subject: [PATCH 14/21] Improve API of unrollVectorValue. Parameters are now: * using TypedValue instead of just Value * using RewriterBase class. Return types are: * changed to FailureOr> instead of passing a Value as a parameter and returning Logical. --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 14 +++++--------- .../Vector/Transforms/LowerVectorToElements.cpp | 8 ++++++-- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 8 +++++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 985f90f6ed955..97163c4532378 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -255,15 +255,11 @@ using UnrollVectorOpFn = LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn); -/// Generic utility for mapping values of type vector -/// to n values of type vector -/// Follows the following pattern: -/// 1. Check if already 1-D. If so, return failure. -/// 2. Check for scalable dimensions. If so, return failure. -/// 3. Returns the values of n vector.extract operations corresponding -/// to the outermost dimension. -LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter, - SmallVector &values); +/// Generic utility for unrolling values of type vector +/// to N values of type vector using vector.extract. If the input +/// is rank-1 or has leading scalable dimension, failure is returned. +FailureOr> unrollVectorValue(TypedValue, + RewriterBase &); } // namespace vector diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 718b947f13715..56d0f61d3c9cc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -25,10 +25,14 @@ struct UnrollToElements final : OpRewritePattern { LogicalResult matchAndRewrite(vector::ToElementsOp op, PatternRewriter &rewriter) const override { - SmallVector vectors; - if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) { + + TypedValue source = op.getSource(); + FailureOr> result = + vector::unrollVectorValue(source, rewriter); + if (failed(result)) { return failure(); } + SmallVector vectors = *result; // May be a large vector. SmallVector results; diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 8d67cd7e80382..d8e96a294005b 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -393,8 +393,10 @@ vector::isValidMaskedInputVector(ArrayRef shape, return success(); } -LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter, - SmallVector &subvectors) { +FailureOr> +vector::unrollVectorValue(TypedValue vector, + RewriterBase &rewriter) { + SmallVector subvectors; VectorType ty = cast(vector.getType()); Location loc = vector.getLoc(); if (ty.getRank() < 2) @@ -409,7 +411,7 @@ LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter, subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i)); } - return success(); + return subvectors; } LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, From 00cc94bf6e01f703cb2a3ea941406039d39d8f56 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 10 Sep 2025 11:43:00 -0700 Subject: [PATCH 15/21] Documentation --- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index d8e96a294005b..56f20c334c50d 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -393,6 +393,20 @@ vector::isValidMaskedInputVector(ArrayRef shape, return success(); } +/// Takes a 2+ dimensional vector as an input +/// returns n vector values produced by n vector.extract operations. +/// I.e. calling unrollVectorValue([[%v]], rewriter) such that +/// +/// %v : vector +/// +/// will produce the following IR changes +/// +/// %v0 = vector.extract %v[0] : vector +/// %v1 = vector.extract %v[1] : vector +/// ... +/// %vnminusone = vector.extract %v[n-1] : vector +/// +/// and returns SmallVector r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]} FailureOr> vector::unrollVectorValue(TypedValue vector, RewriterBase &rewriter) { From 6894af06a787c03f7a72b9d62025d373170e3a88 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 10 Sep 2025 11:49:47 -0700 Subject: [PATCH 16/21] Add transform.apply_patterns.vector.unroll_to_elements --- .../Dialect/Vector/TransformOps/VectorTransformOps.td | 11 +++++++++++ .../Vector/TransformOps/VectorTransformOps.cpp | 5 +++++ mlir/test/python/dialects/transform_vector_ext.py | 2 ++ 3 files changed, 18 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 07a4117a37b2c..72a69a056c46e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op]> { + let description = [{ + Indicates that vector to_elements operations should be unrolled + along the outermost dimension. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerScanPatternsOp : Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index fe066dc04ad55..6bb390aa09d3e 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( vector::populateVectorFromElementsLoweringPatterns(patterns); } +void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorToElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 5a648fe073315..28902b012f7cb 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -48,6 +48,8 @@ def non_configurable_patterns(): vector.ApplyLowerGatherPatternsOp() # CHECK: transform.apply_patterns.vector.unroll_from_elements vector.ApplyUnrollFromElementsPatternsOp() + # CHECK: transform.apply_patterns.vector.unroll_to_elements + vector.ApplyUnrollToElementsPatternsOp() # CHECK: transform.apply_patterns.vector.lower_scan vector.ApplyLowerScanPatternsOp() # CHECK: transform.apply_patterns.vector.lower_shape_cast From 82004a21be08227d5b2b7656813646d74af154c0 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 10 Sep 2025 11:58:16 -0700 Subject: [PATCH 17/21] Minor changes --- .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 4 ++-- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 56d0f61d3c9cc..82a4fab138191 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -20,7 +20,7 @@ using namespace mlir; namespace { -struct UnrollToElements final : OpRewritePattern { +struct UnrollToElements final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ToElementsOp op, @@ -38,7 +38,7 @@ struct UnrollToElements final : OpRewritePattern { SmallVector results; for (const Value &vector : vectors) { auto subElements = - rewriter.create(op.getLoc(), vector); + vector::ToElementsOp::create(rewriter, op.getLoc(), vector); llvm::append_range(results, subElements.getResults()); } rewriter.replaceOp(op, results); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 56f20c334c50d..39dc7a4f284a6 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -401,10 +401,10 @@ vector::isValidMaskedInputVector(ArrayRef shape, /// /// will produce the following IR changes /// -/// %v0 = vector.extract %v[0] : vector -/// %v1 = vector.extract %v[1] : vector +/// %v0 = vector.extract %v[0] : vector from vector +/// %v1 = vector.extract %v[1] : vector from vector /// ... -/// %vnminusone = vector.extract %v[n-1] : vector +/// %vnminusone = vector.extract %v[n-1] : vector from ... /// /// and returns SmallVector r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]} FailureOr> From cb6cf99c8b687b6493d833bf55ada9764f9137a4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 11 Sep 2025 10:05:38 -0700 Subject: [PATCH 18/21] Remove comment and use inline storage for SmallVector --- mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp index 82a4fab138191..a53a183ec31bc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -34,8 +34,7 @@ struct UnrollToElements final : public OpRewritePattern { } SmallVector vectors = *result; - // May be a large vector. - SmallVector results; + SmallVector results; for (const Value &vector : vectors) { auto subElements = vector::ToElementsOp::create(rewriter, op.getLoc(), vector); From db83d2f4af38a3fbfdb29a3dfff284edd712ccd3 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 11 Sep 2025 10:21:41 -0700 Subject: [PATCH 19/21] Add test with transform interpreter --- .../Vector/vector-to-elements-lowering.mlir | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir index e302dbd174322..18bcf7da7959a 100644 --- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s +// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s // CHECK-LABEL: func.func @to_elements_1d( // CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> @@ -9,6 +10,18 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { return %0#0, %0#1 : f32, f32 } +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.unroll_to_elements + } : !transform.any_op + transform.yield + } +} + // ----- // CHECK-LABEL: func.func @to_elements_2d( @@ -22,3 +35,15 @@ func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { %0:4 = vector.to_elements %arg0 : vector<2x2xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.unroll_to_elements + } : !transform.any_op + transform.yield + } +} From ccec33c15fc18ab41cd96a93a23ef09cac85c5e0 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 11 Sep 2025 10:28:14 -0700 Subject: [PATCH 20/21] Use transform dialect library file --- mlir/test/Dialect/Vector/lit.local.cfg | 2 ++ .../Dialect/Vector/td/unroll-elements.mlir | 11 ++++++++ .../Vector/vector-to-elements-lowering.mlir | 27 ++----------------- 3 files changed, 15 insertions(+), 25 deletions(-) create mode 100644 mlir/test/Dialect/Vector/lit.local.cfg create mode 100644 mlir/test/Dialect/Vector/td/unroll-elements.mlir diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg new file mode 100644 index 0000000000000..62743008a3e3a --- /dev/null +++ b/mlir/test/Dialect/Vector/lit.local.cfg @@ -0,0 +1,2 @@ +# Skip the directory with input TD sequences +config.excludes = ["td"] diff --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir new file mode 100644 index 0000000000000..40a90a33b0ac4 --- /dev/null +++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir @@ -0,0 +1,11 @@ +module attributes {transform.with_named_sequence} { + transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.unroll_to_elements + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir index 18bcf7da7959a..9ec0d76599c41 100644 --- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s -// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \ +// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s // CHECK-LABEL: func.func @to_elements_1d( // CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> @@ -10,18 +11,6 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { return %0#0, %0#1 : f32, f32 } -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %f = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %f { - transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.unroll_to_elements - } : !transform.any_op - transform.yield - } -} - // ----- // CHECK-LABEL: func.func @to_elements_2d( @@ -35,15 +24,3 @@ func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { %0:4 = vector.to_elements %arg0 : vector<2x2xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 } - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %f = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %f { - transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.unroll_to_elements - } : !transform.any_op - transform.yield - } -} From 7e52d00fe71355b1d986e5ada996960fbd573e26 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 11 Sep 2025 13:40:35 -0400 Subject: [PATCH 21/21] Update mlir/test/Dialect/Vector/lit.local.cfg Co-authored-by: Jakub Kuderski --- mlir/test/Dialect/Vector/lit.local.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg index 62743008a3e3a..3e9e8f8497624 100644 --- a/mlir/test/Dialect/Vector/lit.local.cfg +++ b/mlir/test/Dialect/Vector/lit.local.cfg @@ -1,2 +1,2 @@ -# Skip the directory with input TD sequences +# Skip the directory with input TD sequences. config.excludes = ["td"]