From 8ee9261d10ea50c1b0606d6ba7ba0232cc525dc3 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 25 Jun 2025 18:57:15 +0000 Subject: [PATCH 1/2] [mlir][Vector] Add `vector.to_elements` lowering to LLVM Only result elements with at least one use are lowered to an `llvm.extractelement` op; unused results produce no op. --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 34 +++++++++++++++++- .../vector-to-llvm-interface.mlir | 36 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d53d11f87efe8..f1543200fb56f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering } }; +/// Conversion pattern for a `vector.to_elements`. +struct VectorToElementsLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = toElementsOp.getLoc(); + auto idxType = typeConverter->convertType(rewriter.getIndexType()); + Value source = adaptor.getSource(); + + SmallVector results(toElementsOp->getNumResults()); + for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) { + // Create an extractelement operation only for results that are not dead. + if (!element.use_empty()) { + auto constIdx = rewriter.create( + loc, idxType, rewriter.getIntegerAttr(idxType, idx)); + auto llvmType = typeConverter->convertType(element.getType()); + + Value result = rewriter.create( + loc, llvmType, source, constIdx); + results[idx] = result; + } + } + + rewriter.replaceOp(toElementsOp, results); + return success(); + } +}; + /// Conversion pattern for vector.step. 
struct VectorScalableStepOpLowering : public ConvertOpToLLVMPattern { @@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering, VectorDeinterleaveOpLowering, VectorFromElementsLowering, - VectorScalableStepOpLowering>(converter); + VectorToElementsLowering, VectorScalableStepOpLowering>( + converter); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 3df14528bac39..8f73e79d7bfc2 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) { // CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector // CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector to vector<1xf32> -// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64 // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64 @@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector { // ----- +// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements + // CHECK-SAME: %[[A:.*]]: vector<4xf32>) + // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : 
vector<4xf32> + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> + // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 + // CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32> + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> + // CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 +func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %a : vector<4xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @vector_to_elements_dead_elements + // CHECK-SAME: %[[A:.*]]: vector<4xf32>) + // CHECK-NOT: llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> + // CHECK-NOT: llvm.mlir.constant(2 : i64) : i64 + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> + // CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 +func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { + %0:4 = vector.to_elements %a : vector<4xf32> + return %0#1, %0#3 : f32, f32 +} + +// ----- + //===----------------------------------------------------------------------===// // vector.step //===----------------------------------------------------------------------===// From cc50307de44b2e700436471857c11c132e7c89d4 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Thu, 26 Jun 2025 17:25:32 +0000 Subject: [PATCH 2/2] Address review feedback: use an early `continue` in the `vector.to_elements` lowering loop, add a `vector.to_elements` section header to the tests, and drop the `vector_` prefix from the test names --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 19 ++++++++++--------- .../vector-to-llvm-interface.mlir | 12 ++++++++---- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git 
a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f1543200fb56f..501d98862672d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -2000,15 +2000,16 @@ struct VectorToElementsLowering SmallVector results(toElementsOp->getNumResults()); for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) { // Create an extractelement operation only for results that are not dead. - if (!element.use_empty()) { - auto constIdx = rewriter.create( - loc, idxType, rewriter.getIntegerAttr(idxType, idx)); - auto llvmType = typeConverter->convertType(element.getType()); - - Value result = rewriter.create( - loc, llvmType, source, constIdx); - results[idx] = result; - } + if (element.use_empty()) + continue; + + auto constIdx = rewriter.create( + loc, idxType, rewriter.getIntegerAttr(idxType, idx)); + auto llvmType = typeConverter->convertType(element.getType()); + + Value result = rewriter.create(loc, llvmType, + source, constIdx); + results[idx] = result; } rewriter.replaceOp(toElementsOp, results); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 8f73e79d7bfc2..c03d67fdc33fa 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -2421,7 +2421,11 @@ func.func @from_elements_0d(%arg0: f32) -> vector { // ----- -// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements +//===----------------------------------------------------------------------===// +// vector.to_elements +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @to_elements_no_dead_elements // CHECK-SAME: %[[A:.*]]: vector<4xf32>) // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: 
%[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32> @@ -2432,14 +2436,14 @@ func.func @from_elements_0d(%arg0: f32) -> vector { // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> // CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 -func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { +func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { %0:4 = vector.to_elements %a : vector<4xf32> return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 } // ----- -// CHECK-LABEL: func.func @vector_to_elements_dead_elements +// CHECK-LABEL: func.func @to_elements_dead_elements // CHECK-SAME: %[[A:.*]]: vector<4xf32>) // CHECK-NOT: llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 @@ -2448,7 +2452,7 @@ func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> // CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 -func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { +func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { %0:4 = vector.to_elements %a : vector<4xf32> return %0#1, %0#3 : f32, f32 }