Skip to content

Commit 6cb9c3e

Browse files
committed
fix typo and add unit dim test case
1 parent 496f306 commit 6cb9c3e

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
647647
Location loc = broadcastOp.getLoc();
648648
VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
649649
VectorType resType = broadcastOp.getResultVectorType();
650-
VectorType newType =
650+
VectorType targetType =
651651
resType.cloneWith(*targetShape, resType.getElementType());
652652
Value result = rewriter.create<arith::ConstantOp>(
653653
loc, resType, rewriter.getZeroAttr(resType));
@@ -668,7 +668,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
668668
SmallVector<int64_t> srcShape(targetShape->end() - rank,
669669
targetShape->end());
670670
SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
671-
// addjust the offset and shape for src if the corresponding dim is 1.
671+
// Adjust the offset and shape for src if the corresponding dim is 1.
672672
for (int64_t i = 0; i < rank; ++i) {
673673
if (srcType.getDimSize(i) == 1) {
674674
srcOffsets[i] = 0;
@@ -680,7 +680,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
680680
}
681681

682682
Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
683-
newSrc, newType);
683+
newSrc, targetType);
684684

685685
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
686686
loc, newOp->getResult(0), result, offsets, strides);

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,4 +333,48 @@ func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
333333
// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
334334
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
335335
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
336-
// CHECK: return [[r3]]
336+
// CHECK: return [[r3]] : vector<4x4xf32>
337+
338+
func.func @vector_broadcast_with_leading_unit_dim(%v: vector<1x4xf32>) -> vector<4x4xf32> {
339+
%0 = vector.broadcast %v : vector<1x4xf32> to vector<4x4xf32>
340+
return %0 : vector<4x4xf32>
341+
}
342+
343+
// CHECK-LABEL: func.func @vector_broadcast_with_leading_unit_dim
344+
// CHECK-SAME: ([[arg0:%.+]]: vector<1x4xf32>) -> vector<4x4xf32> {
345+
// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
346+
// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
347+
// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<1x2xf32> to vector<2x2xf32>
348+
// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
349+
// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
350+
// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<1x2xf32> to vector<2x2xf32>
351+
// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
352+
// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
353+
// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<1x2xf32> to vector<2x2xf32>
354+
// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
355+
// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
356+
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<1x2xf32> to vector<2x2xf32>
357+
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
358+
// CHECK: return [[r3]] : vector<4x4xf32>
359+
360+
func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector<4x4xf32> {
361+
%0 = vector.broadcast %v : vector<4x1xf32> to vector<4x4xf32>
362+
return %0 : vector<4x4xf32>
363+
}
364+
365+
// CHECK-LABEL: func.func @vector_broadcast_with_tailing_unit_dim
366+
// CHECK-SAME: ([[arg0:%.+]]: vector<4x1xf32>) -> vector<4x4xf32> {
367+
// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
368+
// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
369+
// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2x1xf32> to vector<2x2xf32>
370+
// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
371+
// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
372+
// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2x1xf32> to vector<2x2xf32>
373+
// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
374+
// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
375+
// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2x1xf32> to vector<2x2xf32>
376+
// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
377+
// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
378+
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
379+
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
380+
// CHECK: return [[r3]] : vector<4x4xf32>

0 commit comments

Comments
 (0)