Skip to content

Commit b5afd2a

Browse files
authored
Sort Dimensions based on Layout in case of input fusion (#1793)
Fixes slowdown reported in ROCm/rocMLIR-internal#1784
1 parent 4a757de commit b5afd2a

File tree

3 files changed

+258
-12
lines changed

3 files changed

+258
-12
lines changed

mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,27 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Analysis/BufferDependencyAnalysis.h"
1314
#include "mlir/Dialect/Arith/IR/Arith.h"
1415
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/Dialect/Linalg/IR/Linalg.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1618
#include "mlir/Dialect/Rock/IR/Rock.h"
1719
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
1820
#include "mlir/Dialect/Rock/Passes.h"
1921
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
2022
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
2123
#include "mlir/IR/PatternMatch.h"
2224
#include "mlir/IR/Value.h"
25+
#include "mlir/Interfaces/SideEffectInterfaces.h"
2326
#include "mlir/Support/LogicalResult.h"
2427
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28+
#include "llvm/ADT/DenseMap.h"
2529
#include "llvm/ADT/STLExtras.h"
30+
#include "llvm/Support/LogicalResult.h"
2631
#include <limits>
2732
#include <numeric>
33+
#include <optional>
2834

2935
namespace mlir {
3036
namespace rock {
@@ -65,25 +71,116 @@ FailureOr<Container> reorderArrayAttr(Container inputArray,
6571

6672
return reorderedElements;
6773
}
74+
//
75+
76+
// traces input arguments of the GEMM operation back to blockArguments. It
77+
// records sequence of rock.transforms between gemm argument to blockArgument
78+
// if there is any. It is possible that single gemm arg is mapped to multiple
79+
// blockArguments. BlockArguments are recorded in `blockArgs` and series of
80+
// rock.TransformAttr sequences for each `blockArgs` is recorded in
81+
// transformAttrsMap.
82+
static LogicalResult traceGemmInputToBlockArgs(
83+
Value inputArg, PatternRewriter &b,
84+
llvm::DenseMap<Value, SmallVector<Attribute>> &transformAttrsMap,
85+
llvm::SmallSetVector<Value, 2> &blockArgs,
86+
const BufferDependencyAnalysis &deps) {
87+
Value source;
88+
ArrayAttr transforms;
89+
// below call to `rock.untransform` is concatenating existing transform
90+
// sequence on `inputArg` with rock.transform sequence found by tracing upto
91+
// source from `inputArg` as staring point.
92+
// For example,
93+
// SeqExisting -> inputArgs --> Seq --> source
94+
// transforms == SeqExisting + Seq
95+
// transformAttrsMap[inputArg] = SeqExisting
96+
// transformAttrsMap[Source] = SeqExisting + Seq
97+
std::tie(source, transforms, std::ignore) =
98+
rock::untransform(b, inputArg, transformAttrsMap[inputArg]);
99+
// insert transform sequence on source into the map if it doesn't already
100+
// exists. if it does then we've found a loop or case where multiple operators
101+
// are writing to same `memref.alloc`
102+
if (!transformAttrsMap
103+
.insert({source, SmallVector<Attribute>{transforms.begin(),
104+
transforms.end()}})
105+
.second) {
106+
return failure();
107+
}
108+
if (isa<BlockArgument>(source)) {
109+
blockArgs.insert(source);
110+
return success();
111+
}
112+
FailureOr<memref::AllocOp> allocOp = mlir::rock::findMemrefAlloc(source);
113+
if (failed(allocOp)) {
114+
return failure();
115+
}
116+
std::optional<llvm::SmallVector<OpOperand *>> allocOpWriters =
117+
deps.getWriters(allocOp.value());
118+
if (!allocOpWriters.has_value()) {
119+
return failure();
120+
}
121+
bool hasSuccess = false;
122+
for (OpOperand *allocWriteOperand : allocOpWriters.value()) {
123+
auto writerOp =
124+
dyn_cast<MemoryEffectOpInterface>(allocWriteOperand->getOwner());
125+
if (!writerOp)
126+
continue;
127+
SmallVector<MemoryEffects::EffectInstance> effects;
128+
writerOp.getEffects(effects);
129+
for (const MemoryEffects::EffectInstance &effect : effects) {
130+
OpOperand *writerOpOperand = effect.getEffectValue<OpOperand *>();
131+
// test that same buffer is not being read and written to
132+
if (writerOpOperand && isa<MemoryEffects::Read>(effect.getEffect()) &&
133+
writerOpOperand != allocWriteOperand) {
134+
Value writerOpOperandValue = writerOpOperand->get();
135+
// Add existing transform sequences on `writerOpOperandValue` to
136+
// continue concatenating in recursive calls.
137+
transformAttrsMap[writerOpOperandValue] = transformAttrsMap.at(source);
138+
if (succeeded(traceGemmInputToBlockArgs(
139+
writerOpOperandValue, b, transformAttrsMap, blockArgs, deps))) {
140+
hasSuccess = true;
141+
}
142+
}
143+
}
144+
}
145+
// return success if it has found trace to any blockArg
146+
return success(hasSuccess);
147+
}
68148

69149
template <typename Container>
70150
static FailureOr<std::tuple<Value, Container, SmallVector<uint32_t>>>
71151
sortByMemoryLayout(Value tensor, const Container &layout, PatternRewriter &b) {
72-
ArrayAttr transforms;
73-
Value source;
74-
std::tie(source, transforms, std::ignore) = rock::untransform(b, tensor);
75-
76-
if (transforms.empty())
152+
// trace input tensor to blockArgument first and do necessary error checking
153+
llvm::DenseMap<Value, SmallVector<Attribute>> transformAttrsMap;
154+
llvm::SmallSetVector<Value, 2> blockArgs;
155+
BufferDependencyAnalysis deps(tensor.getParentBlock()->getParentOp());
156+
if (failed(traceGemmInputToBlockArgs(tensor, b, transformAttrsMap, blockArgs,
157+
deps))) {
77158
return std::make_tuple(tensor, layout, SmallVector<uint32_t>{});
78-
159+
}
160+
assert(!blockArgs.empty());
161+
SmallVector<Attribute> transformsList;
162+
for (const auto blockArg : blockArgs) {
163+
// make sure all the blockArgs have been mapped to some transform sequence
164+
// or empty transform sequence
165+
if (!transformAttrsMap.contains(blockArg)) {
166+
return std::make_tuple(tensor, layout, SmallVector<uint32_t>{});
167+
}
168+
if (transformsList.empty()) {
169+
transformsList = transformAttrsMap[blockArg];
170+
} else if (transformsList != transformAttrsMap[blockArg]) {
171+
// Currently we do not handle case where some block arg goes through
172+
// different sequence of transforms. All blockArgs must have same
173+
// transforms for now.
174+
return std::make_tuple(tensor, layout, SmallVector<uint32_t>{});
175+
}
176+
}
177+
if (transformsList.empty()) {
178+
return std::make_tuple(tensor, layout, SmallVector<uint32_t>{});
179+
}
180+
ArrayAttr transforms = b.getArrayAttr(transformsList);
79181
rock::TransformMapAttr firstCoordTransform =
80-
cast<rock::TransformMapAttr>(transforms[0]);
182+
cast<rock::TransformMapAttr>(transformsList[0]);
81183
int64_t upperRank = firstCoordTransform.getUpperBounds().size();
82-
83-
// no need to do anything if it's not a block argument
84-
if (!isa<BlockArgument>(source))
85-
return std::make_tuple(tensor, layout, SmallVector<uint32_t>{});
86-
87184
SmallVector<uint32_t> strides(upperRank);
88185
for (int64_t idx = 0; idx < upperRank; idx++) {
89186
FailureOr<llvm::SmallDenseMap<int64_t, SmallVector<rock::SubDimInfo>>>

0 commit comments

Comments
 (0)