|
10 | 10 | // |
11 | 11 | //===----------------------------------------------------------------------===// |
12 | 12 |
|
| 13 | +#include "mlir/Analysis/BufferDependencyAnalysis.h" |
13 | 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | 16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 17 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | 18 | #include "mlir/Dialect/Rock/IR/Rock.h" |
17 | 19 | #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" |
18 | 20 | #include "mlir/Dialect/Rock/Passes.h" |
19 | 21 | #include "mlir/Dialect/Rock/utility/loweringUtils.h" |
20 | 22 | #include "mlir/Dialect/Rock/utility/transformMapUtils.h" |
21 | 23 | #include "mlir/IR/PatternMatch.h" |
22 | 24 | #include "mlir/IR/Value.h" |
| 25 | +#include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | 26 | #include "mlir/Support/LogicalResult.h" |
24 | 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 28 | +#include "llvm/ADT/DenseMap.h" |
25 | 29 | #include "llvm/ADT/STLExtras.h" |
| 30 | +#include "llvm/Support/LogicalResult.h" |
26 | 31 | #include <limits> |
27 | 32 | #include <numeric> |
| 33 | +#include <optional> |
28 | 34 |
|
29 | 35 | namespace mlir { |
30 | 36 | namespace rock { |
@@ -65,25 +71,116 @@ FailureOr<Container> reorderArrayAttr(Container inputArray, |
65 | 71 |
|
66 | 72 | return reorderedElements; |
67 | 73 | } |
| 74 | +// |
| 75 | + |
| 76 | +// traces input arguments of the GEMM operation back to blockArguments. It |
| 77 | +// records sequence of rock.transforms between gemm argument to blockArgument |
| 78 | +// if there is any. It is possible that single gemm arg is mapped to multiple |
| 79 | +// blockArguments. BlockArguments are recorded in `blockArgs` and series of |
| 80 | +// rock.TransformAttr sequences for each `blockArgs` is recorded in |
| 81 | +// transformAttrsMap. |
| 82 | +static LogicalResult traceGemmInputToBlockArgs( |
| 83 | + Value inputArg, PatternRewriter &b, |
| 84 | + llvm::DenseMap<Value, SmallVector<Attribute>> &transformAttrsMap, |
| 85 | + llvm::SmallSetVector<Value, 2> &blockArgs, |
| 86 | + const BufferDependencyAnalysis &deps) { |
| 87 | + Value source; |
| 88 | + ArrayAttr transforms; |
| 89 | + // below call to `rock.untransform` is concatenating existing transform |
| 90 | + // sequence on `inputArg` with rock.transform sequence found by tracing upto |
| 91 | + // source from `inputArg` as staring point. |
| 92 | + // For example, |
| 93 | + // SeqExisting -> inputArgs --> Seq --> source |
| 94 | + // transforms == SeqExisting + Seq |
| 95 | + // transformAttrsMap[inputArg] = SeqExisting |
| 96 | + // transformAttrsMap[Source] = SeqExisting + Seq |
| 97 | + std::tie(source, transforms, std::ignore) = |
| 98 | + rock::untransform(b, inputArg, transformAttrsMap[inputArg]); |
| 99 | + // insert transform sequence on source into the map if it doesn't already |
| 100 | + // exists. if it does then we've found a loop or case where multiple operators |
| 101 | + // are writing to same `memref.alloc` |
| 102 | + if (!transformAttrsMap |
| 103 | + .insert({source, SmallVector<Attribute>{transforms.begin(), |
| 104 | + transforms.end()}}) |
| 105 | + .second) { |
| 106 | + return failure(); |
| 107 | + } |
| 108 | + if (isa<BlockArgument>(source)) { |
| 109 | + blockArgs.insert(source); |
| 110 | + return success(); |
| 111 | + } |
| 112 | + FailureOr<memref::AllocOp> allocOp = mlir::rock::findMemrefAlloc(source); |
| 113 | + if (failed(allocOp)) { |
| 114 | + return failure(); |
| 115 | + } |
| 116 | + std::optional<llvm::SmallVector<OpOperand *>> allocOpWriters = |
| 117 | + deps.getWriters(allocOp.value()); |
| 118 | + if (!allocOpWriters.has_value()) { |
| 119 | + return failure(); |
| 120 | + } |
| 121 | + bool hasSuccess = false; |
| 122 | + for (OpOperand *allocWriteOperand : allocOpWriters.value()) { |
| 123 | + auto writerOp = |
| 124 | + dyn_cast<MemoryEffectOpInterface>(allocWriteOperand->getOwner()); |
| 125 | + if (!writerOp) |
| 126 | + continue; |
| 127 | + SmallVector<MemoryEffects::EffectInstance> effects; |
| 128 | + writerOp.getEffects(effects); |
| 129 | + for (const MemoryEffects::EffectInstance &effect : effects) { |
| 130 | + OpOperand *writerOpOperand = effect.getEffectValue<OpOperand *>(); |
| 131 | + // test that same buffer is not being read and written to |
| 132 | + if (writerOpOperand && isa<MemoryEffects::Read>(effect.getEffect()) && |
| 133 | + writerOpOperand != allocWriteOperand) { |
| 134 | + Value writerOpOperandValue = writerOpOperand->get(); |
| 135 | + // Add existing transform sequences on `writerOpOperandValue` to |
| 136 | + // continue concatenating in recursive calls. |
| 137 | + transformAttrsMap[writerOpOperandValue] = transformAttrsMap.at(source); |
| 138 | + if (succeeded(traceGemmInputToBlockArgs( |
| 139 | + writerOpOperandValue, b, transformAttrsMap, blockArgs, deps))) { |
| 140 | + hasSuccess = true; |
| 141 | + } |
| 142 | + } |
| 143 | + } |
| 144 | + } |
| 145 | + // return success if it has found trace to any blockArg |
| 146 | + return success(hasSuccess); |
| 147 | +} |
68 | 148 |
|
69 | 149 | template <typename Container> |
70 | 150 | static FailureOr<std::tuple<Value, Container, SmallVector<uint32_t>>> |
71 | 151 | sortByMemoryLayout(Value tensor, const Container &layout, PatternRewriter &b) { |
72 | | - ArrayAttr transforms; |
73 | | - Value source; |
74 | | - std::tie(source, transforms, std::ignore) = rock::untransform(b, tensor); |
75 | | - |
76 | | - if (transforms.empty()) |
| 152 | + // trace input tensor to blockArgument first and do necessary error checking |
| 153 | + llvm::DenseMap<Value, SmallVector<Attribute>> transformAttrsMap; |
| 154 | + llvm::SmallSetVector<Value, 2> blockArgs; |
| 155 | + BufferDependencyAnalysis deps(tensor.getParentBlock()->getParentOp()); |
| 156 | + if (failed(traceGemmInputToBlockArgs(tensor, b, transformAttrsMap, blockArgs, |
| 157 | + deps))) { |
77 | 158 | return std::make_tuple(tensor, layout, SmallVector<uint32_t>{}); |
78 | | - |
| 159 | + } |
| 160 | + assert(!blockArgs.empty()); |
| 161 | + SmallVector<Attribute> transformsList; |
| 162 | + for (const auto blockArg : blockArgs) { |
| 163 | + // make sure all the blockArgs have been mapped to some transform sequence |
| 164 | + // or empty transform sequence |
| 165 | + if (!transformAttrsMap.contains(blockArg)) { |
| 166 | + return std::make_tuple(tensor, layout, SmallVector<uint32_t>{}); |
| 167 | + } |
| 168 | + if (transformsList.empty()) { |
| 169 | + transformsList = transformAttrsMap[blockArg]; |
| 170 | + } else if (transformsList != transformAttrsMap[blockArg]) { |
| 171 | + // Currently we do not handle case where some block arg goes through |
| 172 | + // different sequence of transforms. All blockArgs must have same |
| 173 | + // transforms for now. |
| 174 | + return std::make_tuple(tensor, layout, SmallVector<uint32_t>{}); |
| 175 | + } |
| 176 | + } |
| 177 | + if (transformsList.empty()) { |
| 178 | + return std::make_tuple(tensor, layout, SmallVector<uint32_t>{}); |
| 179 | + } |
| 180 | + ArrayAttr transforms = b.getArrayAttr(transformsList); |
79 | 181 | rock::TransformMapAttr firstCoordTransform = |
80 | | - cast<rock::TransformMapAttr>(transforms[0]); |
| 182 | + cast<rock::TransformMapAttr>(transformsList[0]); |
81 | 183 | int64_t upperRank = firstCoordTransform.getUpperBounds().size(); |
82 | | - |
83 | | - // no need to do anything if it's not a block argument |
84 | | - if (!isa<BlockArgument>(source)) |
85 | | - return std::make_tuple(tensor, layout, SmallVector<uint32_t>{}); |
86 | | - |
87 | 184 | SmallVector<uint32_t> strides(upperRank); |
88 | 185 | for (int64_t idx = 0; idx < upperRank; idx++) { |
89 | 186 | FailureOr<llvm::SmallDenseMap<int64_t, SmallVector<rock::SubDimInfo>>> |
|
0 commit comments