|
1 | 1 | #ifndef TRITON_INTEL_UTILS_DEFUSECHAIN_H |
2 | 2 | #define TRITON_INTEL_UTILS_DEFUSECHAIN_H |
3 | 3 |
|
| 4 | +#include "Utils/Utility.h" |
4 | 5 | #include "mlir/IR/Value.h" |
| 6 | +#include "mlir/Interfaces/LoopLikeInterface.h" |
5 | 7 | #include "llvm/ADT/SetVector.h" |
6 | 8 | #include <unordered_set> |
7 | 9 |
|
@@ -97,6 +99,72 @@ class DefUseChainManager { |
97 | 99 | DefUseChains chains; |
98 | 100 | }; |
99 | 101 |
|
| 102 | +/// \class Fuser |
| 103 | +/// Abstract base class providing functionality to fuse operations within a |
| 104 | +/// set of def-use chains. |
| 105 | +class Fuser { |
| 106 | +protected: |
| 107 | + SmallPtrSet<Operation *, 8> cleanUp; |
| 108 | + |
| 109 | + virtual ~Fuser() { |
| 110 | + if (!cleanUp.empty()) |
| 111 | + eraseOperations(cleanUp); |
| 112 | + } |
| 113 | + |
| 114 | + using DefUseChain = intel::DefUseChain; |
| 115 | + using DefUseChainManager = intel::DefUseChainManager; |
| 116 | + using DefUseChains = DefUseChainManager::DefUseChains; |
| 117 | + |
| 118 | + // Delegate to derived classes details on which operations within a |
| 119 | + // DefUseChain to fuse. |
| 120 | + virtual void fuse(const DefUseChain &) = 0; |
| 121 | + |
| 122 | + // Fuse operations in the given \p chains. |
| 123 | + void fuse(const DefUseChains &chains); |
| 124 | + |
| 125 | + // Duplicate the root operation of the given \p chains. |
| 126 | + void duplicateRoot(DefUseChains &chains) const; |
| 127 | + |
| 128 | + // Duplicate the root operation of \p sameRootChains and update \p chains. |
| 129 | + void duplicateRoot(DefUseChains &sameRootChains, DefUseChains &chains) const; |
| 130 | + |
| 131 | + // Prune \p chains that cannot be handled during fusion. For example, |
| 132 | + // operations in the def-use chain should have a single user, except in |
| 133 | + // special circumstances (e.g. the root operation of a chain might have more |
| 134 | + // than one user). |
| 135 | + void pruneInvalid(DefUseChains &chains) const; |
| 136 | + |
| 137 | + // Determine whether all operations in the given def-use \p chain have a |
| 138 | + // single user. Note: we allow an operation in the def-use chain to have an |
| 139 | + // additional user if the operation is in a for loop, and the additional user |
| 140 | + // is the loop yield operation, provided that the result yielded is not used |
| 141 | + // after the loop. Example: |
| 142 | + // make_tensor_ptr -> advance -> load (OK) |
| 143 | + // make_tensor_ptr -> for init_arg -> advance -> load (OK) |
| 144 | + // -> yield (OK) |
| 145 | + // make_tensor_ptr -> for init_arg -> advance -> load (OK) |
| 146 | + // -> yield -> load (NOT OK) |
| 147 | + // |
| 148 | + bool validateChain(const DefUseChain &chain) const; |
| 149 | + |
| 150 | + // Propagate \p newVal to operations in the given def-use \p chain. |
| 151 | + void propagateToUsers(Value newVal, const DefUseChain &chain, |
| 152 | + IRMapping &mapping); |
| 153 | + |
| 154 | + // Propagate \p newVal to users of \p origOp. |
| 155 | + void propagateToUsers(Value newVal, Value origVal, Operation *origOp, |
| 156 | + Operation *sentinel, IRMapping &mapping); |
| 157 | + |
| 158 | + // If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise |
| 159 | + // terminate the propagation. |
| 160 | + virtual void propagateToUser(Value newVal, Value origVal, Operation *user, |
| 161 | + Operation *sentinel, IRMapping &mapping) = 0; |
| 162 | + |
| 163 | + // Propagate \p newVal to users of \p origOp in the given \p loop. |
| 164 | + void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp, |
| 165 | + Operation *sentinel, IRMapping &mapping); |
| 166 | +}; |
| 167 | + |
100 | 168 | } // namespace mlir::triton::intel |
101 | 169 |
|
102 | 170 | #endif // TRITON_INTEL_UTILS_DEFUSECHAIN_H |
0 commit comments