Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :

def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2401,6 +2401,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

/// Returns the shape of the broadcast result vector, which is the shape the
/// unrolling patterns tile over.
std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
Expand Down
74 changes: 69 additions & 5 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};

/// Unrolls a vector.broadcast into smaller broadcasts of the target shape
/// reported by `getTargetShape` for the configured options.
///
/// The full result starts out as a zero constant of the result type. For each
/// tile offset, a broadcast producing one tile is created and written into the
/// running result with vector.insert_strided_slice.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    // Bail out when no target shape is available for this op.
    auto tileShape = getTargetShape(options, broadcastOp);
    if (!tileShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the source is not a vector (scalar broadcast).
    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType fullTy = broadcastOp.getResultVectorType();
    VectorType tileTy = fullTy.cloneWith(*tileShape, fullTy.getElementType());

    // Zero-initialized value that the tiles are inserted into one by one.
    Value accumulated = rewriter.create<arith::ConstantOp>(
        loc, fullTy, rewriter.getZeroAttr(fullTy));

    SmallVector<int64_t> fullShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> unitStrides(fullShape.size(), 1);

    for (SmallVector<int64_t> tileOffsets :
         StaticTileOffsetRange(fullShape, *tileShape)) {
      // A non-vector source is forwarded unchanged to every tile.
      Value tileSrc = broadcastOp.getSource();

      if (srcVecTy) {
        // The source dims line up with the trailing dims of the result, so
        // take the trailing `srcRank` entries of the tile offsets and shape.
        int64_t srcRank = srcVecTy.getRank();
        SmallVector<int64_t> srcOffsets =
            llvm::to_vector(ArrayRef<int64_t>(tileOffsets).take_back(srcRank));
        SmallVector<int64_t> srcSizes =
            llvm::to_vector(ArrayRef<int64_t>(*tileShape).take_back(srcRank));
        SmallVector<int64_t> srcStrides(srcRank, 1);

        // A source dim of size 1 is stretched by the broadcast itself: always
        // read it whole, starting at offset 0.
        for (int64_t dim = 0; dim < srcRank; ++dim) {
          if (srcVecTy.getDimSize(dim) != 1)
            continue;
          srcOffsets[dim] = 0;
          srcSizes[dim] = 1;
        }

        tileSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcSizes, srcStrides);
      }

      // Re-create the broadcast on the (possibly sliced) source, producing a
      // single tile, then place that tile at its offsets in the result.
      Operation *tileBroadcast = cloneOpWithOperandsAndTypes(
          rewriter, loc, broadcastOp, tileSrc, tileTy);
      accumulated = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileBroadcast->getResult(0), accumulated, tileOffsets,
          unitStrides);
    }

    rewriter.replaceOp(broadcastOp, accumulated);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

/// Adds the vector unrolling rewrite patterns to `patterns`.
///
/// Each pattern is constructed with the pattern set's MLIRContext, the given
/// unrolling `options`, and `benefit`.
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  // All patterns, including UnrollBroadcastPattern, go through one add<> call
  // so that each pattern is registered exactly once.
  patterns
      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
           UnrollContractionPattern, UnrollElementwisePattern,
           UnrollReductionPattern, UnrollMultiReductionPattern,
           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
          patterns.getContext(), options, benefit);
}
25 changes: 24 additions & 1 deletion mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
// CHECK: return
// CHECK: return

func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
Expand Down Expand Up @@ -311,3 +311,26 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
// BATCHED-COUNT-16: vector.contract
// BATCHED-NOT: vector.contract
// BATCHED: return


// Broadcasts a 1-D vector<4xf32> to a 2-D vector<4x4xf32>. The expectations
// that follow this function verify that the broadcast is unrolled into 2x2
// tiles, each built from a 2-element slice of the source.
func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
%0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// CHECK-LABEL: func @vector_broadcast
// CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: return [[r3]]
14 changes: 8 additions & 6 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<arith::AddFOp, vector::FMAOp,
vector::MultiDimReductionOp>(op));
}));
patterns,
UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
vector::BroadcastOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2})
Expand Down