Skip to content

Commit 5ad484e

Browse files
committed
cosmetic fixes and better failure notifications
1 parent 31ec6c6 commit 5ad484e

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,25 @@ namespace {
7979
/// The greatest common divisor (gcd) of the first dimension preceding the
8080
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
8181
/// on vectors with shapes that are `multiples` of (what we define as) the
82-
/// 'atomic size', 2x7x11. The atomic size is `gcd` x `common-suffix`.
82+
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
8383
///
8484
/// vector<2x2x3x4x7x11xi8> to
8585
/// vector<8x6x7x11xi8>
8686
/// ^^^^ ---> common suffix of 7x11
8787
/// ^ ---> gcd(4,6) is 2 | |
8888
/// | | |
8989
/// v v v
90-
/// atomic size <----- 2x7x11
90+
/// atomic shape <----- 2x7x11
9191
///
9292
///
9393
///
94-
/// The decomposition implemented in this patterns consists of a sequence of
94+
/// The decomposition implemented in this pattern consists of a sequence of
9595
/// repeated steps:
9696
///
9797
/// (1) Extract vectors from the suffix of the source.
9898
/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
9999
///
100-
/// (2) Do extract_strided_slice down to the atomic size.
100+
/// (2) Do extract_strided_slice down to the atomic shape.
101101
/// In our example this is 4x7x11 -> 2x7x11.
102102
///
103103
/// (3) Do insert_strided_slice to the suffix of the result.
@@ -130,7 +130,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
130130

131131
if (sourceType.isScalable() || resultType.isScalable())
132132
return rewriter.notifyMatchFailure(
133-
op, "shape_cast lowering not handled by this pattern");
133+
op,
134+
"shape_cast where vectors are scalable not handled by this pattern");
134135

135136
const ArrayRef<int64_t> sourceShape = sourceType.getShape();
136137
const ArrayRef<int64_t> resultShape = resultType.getShape();
@@ -332,7 +333,8 @@ class ScalableShapeCastOpRewritePattern
332333
// from >= 2-D scalable vectors or scalable vectors of fixed vectors.
333334
if (!isTrailingDimScalable(sourceVectorType) ||
334335
!isTrailingDimScalable(resultVectorType)) {
335-
return failure();
336+
return rewriter.notifyMatchFailure(
337+
op, "trailing dims are not scalable, not handled by this pattern");
336338
}
337339

338340
// The sizes of the trailing dimension of the source and result vectors, the

0 commit comments

Comments
 (0)