Skip to content

Commit f72fb10

Browse files
committed
Address comments
1 parent 610c8c6 commit f72fb10

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
250250
/// create sub vectors.
251251
/// 5. Insert the sub vectors back into the final vector.
252252
/// 6. Replace the original op with the new result.
253+
///
254+
/// Expects the operation to be unrolled to have at most 1 result. When there's
255+
/// no result, expects the caller to pass in the `vectorTy` to be able to get
256+
/// the unroll factor.
253257
using UnrollVectorOpFn =
254258
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
255259

mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
6565
Value maskVec = op.getMask();
6666
Value valueVec = op.getValueToStore();
6767

68-
// Get the vector type from one of the vector operands
68+
// Get the vector type from one of the vector operands.
6969
VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
7070
if (!vectorTy)
7171
return failure();
@@ -85,8 +85,8 @@ struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
8585
indexSubVec, maskSubVec, valueSubVec,
8686
op.getAlignmentAttr());
8787

88-
// Return a dummy value since unrollVectorOp expects a Value
89-
return rewriter.create<ub::PoisonOp>(loc, subTy);
88+
// Return a dummy value since unrollVectorOp expects a Value.
89+
return Value();
9090
};
9191

9292
return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -433,15 +433,12 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
433433
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
434434
vector::UnrollVectorOpFn unrollFn,
435435
VectorType vectorTy) {
436-
// If vector type is not provided, get it from the result
436+
// If vector type is not provided, get it from the result.
437437
if (!vectorTy) {
438-
if (op->getNumResults() != 1)
439-
return rewriter.notifyMatchFailure(
440-
op, "expected single result when vector type not provided");
441-
438+
assert(op->getNumResults() == 1 &&
439+
"expected single result when vector type not provided");
442440
vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
443-
if (!vectorTy)
444-
return rewriter.notifyMatchFailure(op, "expected vector type");
441+
assert(vectorTy && "expected result to have vector type");
445442
}
446443

447444
if (vectorTy.getRank() < 2)
@@ -454,7 +451,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
454451

455452
Location loc = op->getLoc();
456453

457-
// Only create result value if the operation produces results
454+
// Only create result value if the operation produces results.
458455
Value result;
459456
if (op->getNumResults() > 0) {
460457
result = ub::PoisonOp::create(rewriter, loc, vectorTy);
@@ -465,7 +462,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
465462
for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
466463
Value subVector = unrollFn(rewriter, loc, subTy, i);
467464

468-
// Only insert if we have a result to build
465+
// Only insert if we have a result to build.
469466
if (op->getNumResults() > 0) {
470467
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
471468
}

0 commit comments

Comments
 (0)