diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..8353314ed958b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..3179b4f975404 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2522,6 +2522,11 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+/// The shape used to unroll a broadcast is the shape of its result vector.
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..fc443ab0d138e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,85 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls a `vector.broadcast` into broadcasts of `targetShape`-sized tiles.
+/// Each tile is built from the scalar source, or from the matching slice of
+/// the vector source, and is inserted at its tile offset into a
+/// zero-initialized vector of the original result type.
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType targetType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      if (!srcType) {
+        // Scalar to vector broadcast.
+        newSrc = broadcastOp.getSource();
+      } else {
+        // Vector to vector broadcast.
+        int64_t rank = srcType.getRank();
+        // The source lines up with the trailing `rank` dims of the result.
+        // NOTE(review): assumes targetShape->size() >= rank -- confirm.
+        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - rank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+        // Unit source dims are stretched by the broadcast, so every tile
+        // reads them at offset 0 with extent 1.
+        for (int64_t i = 0; i < rank; ++i) {
+          if (srcType.getDimSize(i) == 1) {
+            srcOffsets[i] = 0;
+            srcShape[i] = 1;
+          }
+        }
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, targetType);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+  patterns
+      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+           UnrollContractionPattern, UnrollElementwisePattern,
+           UnrollReductionPattern, UnrollMultiReductionPattern,
+           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+          patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fbb178fb49d87 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -311,3 +311,70 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
 // BATCHED-COUNT-16: vector.contract
 // BATCHED-NOT: vector.contract
 // BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+// CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_leading_unit_dim(%v: vector<1x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<1x4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_leading_unit_dim
+// CHECK-SAME: ([[arg0:%.+]]: vector<1x4xf32>) -> vector<4x4xf32> {
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_trailing_unit_dim(%v: vector<4x1xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4x1xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_trailing_unit_dim
+// CHECK-SAME: ([[arg0:%.+]]: vector<4x1xf32>) -> vector<4x4xf32> {
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]] : vector<4x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..54aa96ba89a00 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{2, 2})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp,
-                                           vector::MultiDimReductionOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+                      vector::BroadcastOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2})