diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..f46aa0428f12f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,27 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that remove redundant Vector Ops by re-ordering them with
+    e.g. elementwise Ops:
+    ```
+    %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+    %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+    %r = arith.addf %at, %bt : vector<2x4xf32>
+    ```
+    gets converted to:
+    ```
+    %0 = arith.addf %a, %b : vector<4x2xf32>
+    %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+    ```
+    At the moment, these patterns are limited to vector.broadcast and
+    vector.transpose.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..12dcf768dd928 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorReductionToContractPatterns(patterns);
+
+  // TODO: As we now have a dedicated transform for
+  // `populateSinkVectorOpsPatterns`, we can remove it from here.
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
@@ -204,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplySinkVectorPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dc46ed17a374d..b6fac80d871e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1043,6 +1043,66 @@ struct ReorderElementwiseOpsOnBroadcast final
   }
 };
 
+/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+/// This may result in cleaner code when extracting a single value
+/// from a multi-element vector, and also helps canonicalize 1-element vectors
+/// to scalars.
+/// ```
+/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
+/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
+/// %2 = arith.addf %0, %1 : f32
+/// ```
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // TODO: vector::FMAOp is not an ElementwiseMappable even if it claims to be,
+    // as it doesn't support scalars.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        isa<vector::FMAOp>(eltwise))
+      return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+    if (eltwise->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected single result");
+
+    if (!eltwise->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    if (!llvm::all_equal(eltwise->getOperandTypes()))
+      return rewriter.notifyMatchFailure(op, "operand types are different");
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    SmallVector<OpFoldResult> pos = op.getMixedPosition();
+    for (Value arg : eltwise->getOperands()) {
+      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2111,8 +2171,8 @@ void mlir::vector::
 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
   patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
-               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
-                                                 benefit);
+               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index cd83e1239fdda..375fa37bd84b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -59,24 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
-// CHECK-SAME:    %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:     %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:         %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:         %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
-// CHECK:         %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
-// CHECK:         %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
-// CHECK:         %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-
-// CHECK:         %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
-
-// CHECK:         %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:         %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK:         return %[[VAL_21]] : tensor<1x4xf32>
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME:    %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
+// CHECK-SAME:    %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[C79:.*]] = arith.constant 79 : index
+// CHECK:         %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:         %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index
+
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:         %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:         return %[[WRITE]] : tensor<1x4xf32>
 // CHECK:       }
 
 // -----
 
@@ -98,19 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<80x16xf32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: index,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 79 : index
-// CHECK:         %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:         %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:         %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
-// CHECK:         %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK:         return %[[VAL_12]] : tensor<1x4xf32>
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:    %[[ARG1:.*]]: index,
+// CHECK-SAME:    %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C79:.*]] = arith.constant 79 : index
+
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:         %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:         return %[[WRITE]] : tensor<1x4xf32>
 // CHECK:       }
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
new file mode 100644
index 0000000000000..ef17b69b2444c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s
+
+// This is a smoke test for `transform.apply_patterns.vector.sink_ops`; this
+// file is also used in `vector-sink.mlir`.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.sink_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 7ce840575a803..8c8f1797aaab6 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s
 //-----------------------------------------------------------------------------
 // [Pattern: ReorderElementwiseOpsOnBroadcast]
 //-----------------------------------------------------------------------------
@@ -423,3 +424,92 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
   %r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
   return %r : vector<6x[4]x2x3xf32>
 }
+
+// -----
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromElementwise]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_elementwise_scalar
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:  %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:  %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:  %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:  return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_elementwise_arg_res_different_types
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xindex>)
+func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
+// CHECK:  %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
+// CHECK:  %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
+// CHECK:  return %[[RES]] : i64
+  %0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
+  %1 = vector.extract %0[1] : i64 from vector<4xi64>
+  return %1 : i64
+}
+
+// CHECK-LABEL: @extract_elementwise_vec
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:  %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:  %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:  %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:  return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_elementwise_no_single_use
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate the extract, as the elementwise op has other uses.
+// CHECK:  %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:  %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:  return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_elementwise_not_one_res
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
+func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+// Do not propagate the extract, as the elementwise op has more than one result.
+// CHECK:  %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
+// CHECK:  %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
+// CHECK:  return %[[EXT]] : i32
+  %low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
+  %1 = vector.extract %low[1] : i32 from vector<4xi32>
+  return %1 : i32
+}
+
+// CHECK-LABEL: @negative_extract_not_elementwise
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xi64>)
+func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
+// `test.increment` is not an elementwise op.
+// CHECK:  %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
+// CHECK:  %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
+// CHECK:  return %[[RES]] : i64
+  %0 = test.increment %arg0: vector<4xi64>
+  %1 = vector.extract %0[1] : i64 from vector<4xi64>
+  return %1 : i64
+}
+
+// CHECK-LABEL: @negative_extract_vec_fma
+// CHECK-SAME:  (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>)
+func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 {
+// `vector.fma` doesn't support scalars.
+// CHECK:  %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32>
+// CHECK:  %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32>
+// CHECK:  return %[[RES]] : f32
+  %0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
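
Note (not part of the patch): besides the `transform.apply_patterns.vector.sink_ops` path exercised by the new test, the same pattern set can be driven directly from C++. The sketch below is illustrative only; the helper name `sinkVectorOps` is made up, while `populateSinkVectorOpsPatterns` and the greedy rewrite driver are existing MLIR APIs (the driver entry point shown here as `applyPatternsAndFoldGreedily` has been renamed `applyPatternsGreedily` in newer MLIR).

```cpp
// Illustrative sketch -- not part of this patch. Runs the vector sink
// patterns (including the new ExtractOpFromElementwise) over `root` from C++.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult sinkVectorOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Same populate call that ApplySinkVectorPatternsOp dispatches to.
  mlir::vector::populateSinkVectorOpsPatterns(patterns);
  // Apply the patterns greedily until a fixed point is reached.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```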