diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 3d963dea2f572..359590f2434dc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -198,14 +198,14 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc, case vector::CombiningKind::ADD: case vector::CombiningKind::XOR: // Initialize reduction vector to: | 0 | .. | 0 | r | - return rewriter.create( - loc, r, constantZero(rewriter, loc, vtp), - constantIndex(rewriter, loc, 0)); + return rewriter.create(loc, r, + constantZero(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); case vector::CombiningKind::MUL: // Initialize reduction vector to: | 1 | .. | 1 | r | - return rewriter.create( - loc, r, constantOne(rewriter, loc, vtp), - constantIndex(rewriter, loc, 0)); + return rewriter.create(loc, r, + constantOne(rewriter, loc, vtp), + constantIndex(rewriter, loc, 0)); case vector::CombiningKind::AND: case vector::CombiningKind::OR: // Initialize reduction vector to: | r | .. | r | r | @@ -628,31 +628,49 @@ struct ForOpRewriter : public OpRewritePattern { const VL vl; }; +static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op, + Value inp) { + if (auto redOp = inp.getDefiningOp()) { + if (auto forOp = redOp.getVector().getDefiningOp()) { + if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) { + rewriter.replaceOp(op, redOp.getVector()); + return success(); + } + } + } + return failure(); +} + /// Reduction chain cleanup. /// v = for { } -/// s = vsum(v) v = for { } -/// u = expand(s) -> for (v) { } +/// s = vsum(v) v = for { } +/// u = broadcast(s) -> for (v) { } /// for (u) { } -template -struct ReducChainRewriter : public OpRewritePattern { +struct ReducChainBroadcastRewriter + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(VectorOp op, + LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { - Value inp = op.getSource(); - if (auto redOp = inp.getDefiningOp()) { - if (auto forOp = redOp.getVector().getDefiningOp()) { - if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) { - rewriter.replaceOp(op, redOp.getVector()); - return success(); - } - } - } - return failure(); + return cleanReducChain(rewriter, op, op.getSource()); } }; +/// Reduction chain cleanup. +/// v = for { } +/// s = vsum(v) v = for { } +/// u = insert(s) -> for (v) { } +/// for (u) { } +struct ReducChainInsertRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::InsertOp op, + PatternRewriter &rewriter) const override { + return cleanReducChain(rewriter, op, op.getValueToStore()); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -668,6 +686,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns, vector::populateVectorStepLoweringPatterns(patterns); patterns.add(patterns.getContext(), vectorLength, enableVLAVectorization, enableSIMDIndex32); - patterns.add, - ReducChainRewriter>(patterns.getContext()); + patterns.add( + patterns.getContext()); } diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir index 2475aa5139da4..b2dfbeb53fde8 100755 --- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir +++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir @@ -22,7 +22,7 @@ // CHECK-NOVEC: } // // CHECK-VEC-LABEL: func.func @sum_reduction -// CHECK-VEC: vector.insertelement +// CHECK-VEC: vector.insert // CHECK-VEC: scf.for // CHECK-VEC: vector.create_mask // CHECK-VEC: vector.maskedload diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir index 364ba6e71ff3b..64235c7227800 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir @@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, // CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index // CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> // CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> +// CHECK-VEC16: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32> // CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { // CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> // CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> @@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, // CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index // CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> // CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> +// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32> // CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { // CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> // CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> @@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, // CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref // CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale // CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32> +// CHECK-VEC4-SVE: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32> // CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) { // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]] // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir index f4b565c7f9c8a..0ab72897d7bc3 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir @@ -82,7 +82,7 @@ // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index // CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64 // CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64> +// CHECK: %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64> // CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) { // CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]] // CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1> diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir index 01b717090e87a..6effbbf98abb7 100644 --- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir +++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir @@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor, // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32> +// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32> // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> @@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor, // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32> +// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32> // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> @@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor, // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32> +// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32> // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> @@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor, // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32> +// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32> // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> @@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor, // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32> +// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32> // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>