Skip to content

Commit 4789d11

Browse files
committed
address comements 2/2
1 parent 3709f67 commit 4789d11

File tree

14 files changed

+205
-218
lines changed

14 files changed

+205
-218
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
147147
Pure,
148148
ViewLikeOpInterface,
149149
SameOperandsAndResultType,
150-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
150+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
151151
]> {
152152
let summary =
153153
"assumption that gives alignment information to the input memref";
@@ -458,7 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
458458
def MemRef_CastOp : MemRef_Op<"cast", [
459459
DeclareOpInterfaceMethods<CastOpInterface>,
460460
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
461+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
462462
MemRefsNormalizable,
463463
Pure,
464464
SameOperandsAndResultShape,
@@ -1197,7 +1197,7 @@ def LoadOp : MemRef_Op<"load",
11971197
"memref", "result",
11981198
"::llvm::cast<MemRefType>($_self).getElementType()">,
11991199
MemRefsNormalizable,
1200-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1200+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12011201
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
12021202
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
12031203
let summary = "load operation";
@@ -1381,7 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
13811381
def MemRef_ReinterpretCastOp
13821382
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
13831383
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1384-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1384+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
13851385
AttrSizedOperandSegments,
13861386
MemRefsNormalizable,
13871387
Pure,
@@ -1609,7 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
16091609

16101610
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
16111611
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1612-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1612+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
16131613
Pure,
16141614
ViewLikeOpInterface]> {
16151615
let summary = "memref reshape operation";
@@ -1708,7 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17081708

17091709
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17101710
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1711-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1711+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
17121712
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17131713
let summary = "operation to produce a memref with a higher rank.";
17141714
let description = [{
@@ -1831,7 +1831,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
18311831

18321832
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
18331833
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1834-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
1834+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
18351835
]> {
18361836
let summary = "operation to produce a memref with a smaller rank.";
18371837
let description = [{
@@ -1939,7 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19391939
"memref", "value",
19401940
"::llvm::cast<MemRefType>($_self).getElementType()">,
19411941
MemRefsNormalizable,
1942-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1942+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
19431943
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19441944
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
19451945
let summary = "store operation";
@@ -2017,7 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20172017

20182018
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20192019
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2020-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
2020+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20212021
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20222022
AttrSizedOperandSegments,
20232023
OffsetSizeAndStrideOpInterface,
@@ -2293,7 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
22932293

22942294
def MemRef_TransposeOp : MemRef_Op<"transpose", [
22952295
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2296-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
2296+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
22972297
Pure]>,
22982298
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
22992299
Results<(outs AnyStridedMemRef)> {
@@ -2329,7 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
23292329

23302330
def MemRef_ViewOp : MemRef_Op<"view", [
23312331
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2332-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
2332+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
23332333
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
23342334
Pure]> {
23352335
let summary = "memref view operation";
@@ -2406,7 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
24062406
//===----------------------------------------------------------------------===//
24072407

24082408
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2409-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
2409+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
24102410
AllTypesMatch<["value", "result"]>,
24112411
TypesMatchWith<"value type matches element type of memref",
24122412
"memref", "value",

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,7 +1247,7 @@ def Vector_TransferReadOp :
12471247
DeclareOpInterfaceMethods<MaskableOpInterface>,
12481248
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
12491249
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1250-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1250+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12511251
AttrSizedOperandSegments,
12521252
DestinationStyleOpInterface
12531253
]>,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
14951495
DeclareOpInterfaceMethods<MaskableOpInterface>,
14961496
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
14971497
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1498-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
1498+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
14991499
AttrSizedOperandSegments,
15001500
DestinationStyleOpInterface
15011501
]>,
@@ -1652,7 +1652,7 @@ def Vector_TransferWriteOp :
16521652

16531653
def Vector_LoadOp : Vector_Op<"load", [
16541654
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1655-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
1655+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
16561656
]> {
16571657
let summary = "reads an n-D slice of memory into an n-D vector";
16581658
let description = [{
@@ -1769,7 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
17691769

17701770
def Vector_StoreOp : Vector_Op<"store", [
17711771
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1772-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
1772+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
17731773
]> {
17741774
let summary = "writes an n-D vector to an n-D slice of memory";
17751775
let description = [{
@@ -1874,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
18741874
}
18751875

18761876
def Vector_MaskedLoadOp :
1877-
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
1877+
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
18781878
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18791879
Variadic<Index>:$indices,
18801880
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1966,7 +1966,7 @@ def Vector_MaskedLoadOp :
19661966
}
19671967

19681968
def Vector_MaskedStoreOp :
1969-
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
1969+
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
19701970
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19711971
Variadic<Index>:$indices,
19721972
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2046,7 +2046,7 @@ def Vector_MaskedStoreOp :
20462046
def Vector_GatherOp :
20472047
Vector_Op<"gather", [
20482048
DeclareOpInterfaceMethods<MaskableOpInterface>,
2049-
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
2049+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20502050
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
20512051
]>,
20522052
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2150,7 +2150,7 @@ def Vector_GatherOp :
21502150
}
21512151

21522152
def Vector_ScatterOp :
2153-
Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
2153+
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
21542154
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21552155
Variadic<Index>:$offsets,
21562156
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2235,7 +2235,7 @@ def Vector_ScatterOp :
22352235
}
22362236

22372237
def Vector_ExpandLoadOp :
2238-
Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
2238+
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
22392239
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22402240
Variadic<Index>:$indices,
22412241
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2323,7 +2323,7 @@ def Vector_ExpandLoadOp :
23232323
}
23242324

23252325
def Vector_CompressStoreOp :
2326-
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
2326+
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
23272327
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23282328
Variadic<Index>:$indices,
23292329
FixedVectorOfNonZeroRankOf<[I1]>:$mask,

mlir/include/mlir/Interfaces/MemOpInterfaces.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@ namespace detail {
2121
/// Attempt to verify the given memory space cast operation.
2222
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
2323

24-
/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
25-
/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
26-
/// true. It returns failure if `operand` doesn't reference a
24+
/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
25+
/// referenced by `operand`. On success, it returns `results` and true. It
26+
/// returns failure if `operand` doesn't reference a
2727
/// `MemorySpaceCastOpInterface` op.
28-
FailureOr<SmallVector<Value>>
29-
fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
30-
bool &modifiedInPlace);
28+
FailureOr<std::pair<SmallVector<Value>, bool>>
29+
bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
3130
} // namespace detail
3231
} // namespace mlir
3332

mlir/include/mlir/Interfaces/MemOpInterfaces.td

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,26 @@
1616
include "mlir/IR/OpBase.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
1818

19-
def FuseMemorySpaceCastConsumerOpInterface :
20-
OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
19+
def MemorySpaceCastConsumerOpInterface :
20+
OpInterface<"MemorySpaceCastConsumerOpInterface"> {
2121
let description = [{
22-
An interface to fuse memory-space cast operands into a consumer operation.
23-
It is the responsibility of the interface to determine which casts can be
24-
fused into the operation.
22+
An interface for operations that can consume memory-space cast-like
23+
operations.
2524
}];
2625
let cppNamespace = "::mlir";
2726
let methods = [
2827
InterfaceMethod<[{
29-
Attempt to fuse the incoming cast-like operands. Returns `success`
30-
and any new results on fusion success, otherwise it returns failure.
28+
Attempt to bubble-down the incoming cast-like operands. On success
29+
returns any new results, and whether the operation was modified in
30+
place, otherwise it returns failure.
3131
If new results are produced, these must be compatible with the original
3232
operation results.
3333

34-
The `modifiedInPlace` parameter indicates whether the operation was
35-
modified in place. If `false` and the fusion succeeded, then the
36-
interface guarantees it is valid to erase the original operation.
37-
If `true`, then the interface must guarantee no operations were created
38-
by the method, and that no further IR modification is necessary. It is
39-
considered an error if `modifiedInPlace` is true and the fusion failed.
34+
If the operation was not modified in place, then the interface
35+
guarantees it is valid to erase the original operation.
36+
If the operation was modified in place, then the interface must
37+
guarantee no operations were created by the method, and that no further
38+
IR modification is necessary.
4039

4140
Any implementations of this method must not erase/replace the original
4241
operation, instead it is the caller responsibility to erase or replace
@@ -45,8 +44,9 @@ def FuseMemorySpaceCastConsumerOpInterface :
4544
Finally, any implementations of this method have to guarantee that the
4645
IR remains valid at all times.
4746
}],
48-
"::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
49-
(ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
47+
"::llvm::FailureOr<std::pair<::llvm::SmallVector<::mlir::Value>, bool>>",
48+
"bubbleDownCasts",
49+
(ins "::mlir::OpBuilder &":$builder)
5050
>,
5151
];
5252
}
@@ -83,13 +83,16 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
8383
Clones the memory space cast op with the given source and target type.
8484
}],
8585
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
86-
(ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
86+
(ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
8787
"::mlir::Value":$src)
8888
>,
8989
InterfaceMethod<[{
90-
Returns whether the cast allows to be fused.
90+
Returns whether the memory-space cast is lossless. A lossless
91+
memory-space cast must not lose any information encoded in the memory
92+
space. An example of such cast, is any conversion to the generic memory
93+
space.
9194
}],
92-
"bool", "isFusableMemorySpaceCast"
95+
"bool", "isLosslessCast"
9396
>
9497
];
9598
let verify = [{
@@ -99,12 +102,12 @@ def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
99102
let extraClassDeclaration = [{
100103
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
101104
/// is produced by a `MemorySpaceCastOpInterface` op, and
102-
/// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
105+
/// `isLosslessCast` returns true, otherwise it returns null.
103106
static ::mlir::MemorySpaceCastOpInterface
104-
getIfFusableCast(::mlir::Value value) {
107+
getIfLosslessCast(::mlir::Value value) {
105108
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
106109
value.getDefiningOp());
107-
if (!op || !op.isFusableMemorySpaceCast())
110+
if (!op || !op.isLosslessCast())
108111
return nullptr;
109112
return op;
110113
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===-- BubbleDownMemorySpaceCasts.h - Bubble down cast patterns ---C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
10+
#define MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
11+
12+
namespace mlir {
13+
class PatternBenefit;
14+
class RewritePatternSet;
15+
/// Collect a set of patterns to bubble-down memory-space cast operations.
16+
void populateBubbleDownMemorySpaceCastPatterns(RewritePatternSet &patterns,
17+
PatternBenefit benefit);
18+
} // namespace mlir
19+
20+
#endif // MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H

mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h

Lines changed: 0 additions & 20 deletions
This file was deleted.

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class GreedyRewriteConfig;
4646
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
4747
#define GEN_PASS_DECL_TOPOLOGICALSORT
4848
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
49-
#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTOCONSUMERS
49+
#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
5050
#include "mlir/Transforms/Passes.h.inc"
5151

5252
/// Creates an instance of the Canonicalizer pass, configured with default

mlir/include/mlir/Transforms/Passes.td

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -585,14 +585,13 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
585585
];
586586
}
587587

588-
def FuseMemorySpaceCastsIntoConsumers :
589-
Pass<"fuse-memory-space-casts-into-consumers"> {
590-
let summary = "Fuses memory-space cast operations into consumers.";
588+
def BubbleDownMemorySpaceCasts :
589+
Pass<"bubble-down-memory-space-casts"> {
590+
let summary = "Bubbles down memory-space cast operations.";
591591
let description = [{
592-
This pass tries to iteratively fuse all possible memory-space cast
593-
operations into their consumers. It does this by looking for
594-
`FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
595-
interface methods to perform the fusion.
592+
This pass tries to iteratively bubble down all possible memory-space cast
593+
operations. It does this by looking for `MemorySpaceCastConsumerOpInterface`
594+
operations, and invoking the interface methods to perform the bubbling down.
596595

597596
Example:
598597

@@ -609,7 +608,7 @@ def FuseMemorySpaceCastsIntoConsumers :
609608
%atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
610609
return %collapsed : memref<16xf32>
611610
}
612-
// mlir-opt --fuse-casts-into-consumers
611+
// mlir-opt --bubble-down-memory-space-casts
613612
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
614613
%c4 = arith.constant 4 : index
615614
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)