Skip to content

Commit 496f306

Browse files
committed
fix a bug
1 parent df06eea commit 496f306

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
658658
for (SmallVector<int64_t> offsets :
659659
StaticTileOffsetRange(originalShape, *targetShape)) {
660660
Value newSrc;
661-
// Scalar to vector broadcast.
662661
if (!srcType) {
662+
// Scalar to vector broadcast.
663663
newSrc = broadcastOp.getSource();
664664
} else {
665+
// Vector to vector broadcast.
665666
int64_t rank = srcType.getRank();
666-
auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).take_back(rank);
667-
auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).take_back(rank);
668-
auto srcStrides = llvm::ArrayRef<int64_t>(strides).take_back(rank);
667+
SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
668+
SmallVector<int64_t> srcShape(targetShape->end() - rank,
669+
targetShape->end());
670+
SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
671+
// addjust the offset and shape for src if the corresponding dim is 1.
672+
for (int64_t i = 0; i < rank; ++i) {
673+
if (srcType.getDimSize(i) == 1) {
674+
srcOffsets[i] = 0;
675+
srcShape[i] = 1;
676+
}
677+
}
669678
newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
670679
loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
671680
}

0 commit comments

Comments
 (0)