Skip to content

Commit 36fa0e9

Browse files
committed
Address feedback
1 parent d59390b commit 36fa0e9

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -465,24 +465,27 @@ struct UnrollElementwisePattern : public RewritePattern {
465465
auto targetShape = getTargetShape(options, op);
466466
if (!targetShape)
467467
return failure();
468+
int64_t targetShapeSize = targetShape->size();
468469
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
469470
SmallVector<int64_t> originalSize =
470471
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
472+
int64_t originalShapeSize = originalSize.size();
471473

472474
Location loc = op->getLoc();
473475

474476
// Handle rank mismatch by adding leading unit dimensions to targetShape
475-
SmallVector<int64_t> adjustedTargetShape = *targetShape;
476-
if (originalSize.size() > targetShape->size()) {
477-
// Add leading unit dimensions to targetShape
478-
int64_t rankDiff = originalSize.size() - targetShape->size();
479-
adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
480-
}
481-
477+
SmallVector<int64_t> adjustedTargetShape(originalShapeSize);
478+
int64_t rankDiff = originalShapeSize - targetShapeSize;
479+
std::fill(adjustedTargetShape.begin(),
480+
adjustedTargetShape.begin() + rankDiff, 1);
481+
std::copy(targetShape->begin(), targetShape->end(),
482+
adjustedTargetShape.begin() + rankDiff);
483+
484+
int64_t adjustedTargetShapeSize = adjustedTargetShape.size();
482485
// Prepare the result vector.
483486
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
484487
rewriter.getZeroAttr(dstVecType));
485-
SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
488+
SmallVector<int64_t> strides(adjustedTargetShapeSize, 1);
486489
VectorType computeVecType =
487490
VectorType::get(*targetShape, dstVecType.getElementType());
488491

@@ -500,7 +503,7 @@ struct UnrollElementwisePattern : public RewritePattern {
500503
loc, operand.get(), offsets, adjustedTargetShape, strides);
501504

502505
// Reshape to remove leading unit dims if needed
503-
if (adjustedTargetShape.size() > targetShape->size()) {
506+
if (adjustedTargetShapeSize > targetShapeSize) {
504507
extracted = rewriter.createOrFold<vector::ShapeCastOp>(
505508
loc, VectorType::get(*targetShape, vecType.getElementType()),
506509
extracted);
@@ -515,8 +518,8 @@ struct UnrollElementwisePattern : public RewritePattern {
515518

516519
// Use strides sized to targetShape for proper insertion
517520
SmallVector<int64_t> insertStrides =
518-
(adjustedTargetShape.size() > targetShape->size())
519-
? SmallVector<int64_t>(targetShape->size(), 1)
521+
(adjustedTargetShapeSize > targetShapeSize)
522+
? SmallVector<int64_t>(targetShapeSize, 1)
520523
: strides;
521524

522525
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -193,30 +193,31 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
193193
return %0 : vector<3x2x2xf32>
194194
}
195195
// CHECK-LABEL: func @vector_fma_3d
196+
// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
196197
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
197-
// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
198-
// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
199-
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
200-
// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
201-
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
202-
// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
203-
// CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
198+
// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
199+
// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
200+
// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
201+
// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
202+
// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
203+
// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
204+
// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
204205
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
205-
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
206-
// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
207-
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
208-
// CHECK: %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
209-
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
210-
// CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
211-
// CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
206+
// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
207+
// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
208+
// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
209+
// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
210+
// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
211+
// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
212+
// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
212213
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
213-
// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
214-
// CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
215-
// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
216-
// CHECK: %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
217-
// CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
218-
// CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
219-
// CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
214+
// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
215+
// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
216+
// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
217+
// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
218+
// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
219+
// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
220+
// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
220221
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
221222
// CHECK: return %[[I2]] : vector<3x2x2xf32>
222223

0 commit comments

Comments
 (0)