Skip to content

Commit 8c6f310

Browse files
committed
Address Feedback
1 parent 983c12b commit 8c6f310

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,32 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809
vector::UnrollVectorOptions options;
810810
};
811811

812+
/// This pattern unrolls `vector.step` operations according to the provided
813+
/// target unroll shape. It decomposes a large step vector into smaller step
814+
/// vectors (segments) and assembles the result by inserting each computed
815+
/// segment into the appropriate offset of the original vector.
816+
///
817+
/// The pattern does not support scalable vectors and will fail to match them.
818+
///
819+
/// For each segment, it adds the base step vector and the segment's offset,
820+
/// then inserts the result into the output vector at the corresponding
821+
/// position.
822+
///
823+
/// Example:
824+
/// Given a step operation:
825+
/// %0 = vector.step : vector<8xindex>
826+
///
827+
/// and a target unroll shape of <4>, the pattern produces:
828+
///
829+
/// %base = vector.step : vector<4xindex>
830+
/// %zero = arith.constant dense<0> : vector<4xindex>
831+
/// %result0 = vector.insert_strided_slice %base, %zero
832+
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
833+
/// %offset = arith.constant dense<4> : vector<4xindex>
834+
/// %segment1 = arith.addi %base, %offset : vector<4xindex>
835+
/// %result1 = vector.insert_strided_slice %segment1, %result0
836+
/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
837+
///
812838
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
813839
UnrollStepPattern(MLIRContext *context,
814840
const vector::UnrollVectorOptions &options,
@@ -817,7 +843,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
817843

818844
LogicalResult matchAndRewrite(vector::StepOp stepOp,
819845
PatternRewriter &rewriter) const override {
820-
auto targetShape = getTargetShape(options, stepOp);
846+
std::optional<SmallVector<int64_t>> targetShape =
847+
getTargetShape(options, stepOp);
821848
if (!targetShape)
822849
return failure();
823850

@@ -833,18 +860,18 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
833860
Value result = arith::ConstantOp::create(rewriter, loc, vecType,
834861
rewriter.getZeroAttr(vecType));
835862

863+
VectorType targetVecType =
864+
VectorType::get(*targetShape, vecType.getElementType());
865+
Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
836866
for (SmallVector<int64_t> offsets :
837867
StaticTileOffsetRange({originalSize}, *targetShape)) {
838-
int64_t tileOffset = offsets[0];
839-
auto targetVecType =
840-
VectorType::get(*targetShape, vecType.getElementType());
841-
Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
842-
Value offsetVal =
843-
rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
844-
Value bcastOffset =
845-
rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
868+
Value bcastOffset = arith::ConstantOp::create(
869+
rewriter, loc, targetVecType,
870+
DenseElementsAttr::get(
871+
targetVecType,
872+
IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
846873
Value tileStep =
847-
rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
874+
arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
848875

849876
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
850877
loc, tileStep, result, offsets, strides);

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -427,19 +427,16 @@ func.func @vector_step() -> vector<32xindex> {
427427
return %0 : vector<32xindex>
428428
}
429429
// CHECK-LABEL: func @vector_step
430-
// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
431-
// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
430+
// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
431+
// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
432432
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
433-
// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
434-
// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
435-
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
436-
// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
437-
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
433+
// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
434+
// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
435+
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
436+
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
438437
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
439-
// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
440-
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
438+
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
441439
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
442-
// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
443-
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
440+
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
444441
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
445442
// CHECK: return %[[INS3]] : vector<32xindex>

0 commit comments

Comments
 (0)