Skip to content

Commit a959b60

Browse files
committed
review comments
1 parent cfaef9d commit a959b60

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ struct ArithBuilder {
121121
};
122122

123123
/// ArithBuilder specialized specifically for tensor/memref indexing
124-
/// calculations. Those calculations generally should never signed overflow, so
125-
/// we can set oveflow flags accordingly.
124+
/// calculations. Those calculations generally should never signed overflow and
125+
/// always use signed integers, so we can set oveflow flags accordingly.
126126
struct ArithIndexingBuilder : public ArithBuilder {
127127
ArithIndexingBuilder(OpBuilder &b, Location loc)
128128
: ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
458458
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
459459
let description = [{
460460
Patterns that remove redundant Vector Ops by re-ordering them with
461-
e.g. elementwise Ops:
461+
e.g. elementwise Ops.
462+
463+
Example:
462464
```
463465
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
464466
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -480,8 +482,11 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
480482
"apply_patterns.vector.sink_mem_ops",
481483
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
482484
let description = [{
483-
Patterns that remove redundant Vector Ops by merging them with load/store
484-
ops
485+
Patterns that replace redundant Vector Ops (followed by
486+
`vector.load`/`vector.store`) with either vector.load/vector.store or
487+
`memref.load`/`memref.store`. Currently limited to 1-element vectors.
488+
489+
Example:
485490
```
486491
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
487492
vector.extract %0[1] : f32 from vector<4xf32>

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,15 +1047,18 @@ class ExtractOpFromElementwise final
10471047
}
10481048
};
10491049

1050+
/// Check if the element type is suitable for vector.load/store sinking.
1051+
/// Element type must be index or byte-aligned integer or floating-point type.
10501052
static bool isSupportedMemSinkElementType(Type type) {
10511053
if (isa<IndexType>(type))
10521054
return true;
10531055

1054-
// Non-byte-aligned types are tricky, skip them.
10551056
return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
10561057
}
10571058

1058-
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
1059+
/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load.
1060+
/// Only index and byte-aligned integer and floating-point element types are
1061+
/// supported for now.
10591062
///
10601063
/// Example:
10611064
/// ```
@@ -1088,8 +1091,11 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
10881091
"scalable vectors are not supported");
10891092

10901093
MemRefType memType = loadOp.getMemRefType();
1094+
1095+
// Non-byte-aligned types are tricky and may require special handling,
1096+
// ignore them for now.
10911097
if (!isSupportedMemSinkElementType(memType.getElementType()))
1092-
return rewriter.notifyMatchFailure(op, "unsupported memref element type");
1098+
return rewriter.notifyMatchFailure(op, "unsupported element type");
10931099

10941100
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
10951101
if (rankOffset < 0)
@@ -1161,7 +1167,7 @@ class StoreOpFromSplatOrBroadcast final
11611167

11621168
if (vecType.getNumElements() != 1)
11631169
return rewriter.notifyMatchFailure(
1164-
op, "only 1-element, vectors are supported");
1170+
op, "only 1-element vectors are supported");
11651171

11661172
Operation *splat = op.getValueToStore().getDefiningOp();
11671173
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
@@ -2253,6 +2259,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
22532259

22542260
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
22552261
PatternBenefit benefit) {
2262+
// TODO: Consider converting these patterns to canonicalizations.
22562263
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
22572264
patterns.getContext(), benefit);
22582265
}

0 commit comments

Comments
 (0)