|
19 | 19 | #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" |
20 | 20 | #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
21 | 21 | #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" |
| 22 | +#include "llvm/ADT/EquivalenceClasses.h" |
22 | 23 | #include "llvm/ADT/STLExtras.h" |
23 | 24 | #include "llvm/ADT/SmallVectorExtras.h" |
24 | 25 | #include "llvm/ADT/TypeSwitch.h" |
25 | 26 | #include "llvm/Support/CommandLine.h" |
26 | 27 | #include "llvm/Support/DebugLog.h" |
27 | 28 | #include "llvm/Support/InterleavedRange.h" |
28 | 29 | #include "llvm/Support/MathExtras.h" |
| 30 | +#include "mlir/Analysis/TopologicalSortUtils.h" |
29 | 31 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
30 | 32 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
31 | 33 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
|
39 | 41 | #include "mlir/IR/OpDefinition.h" |
40 | 42 | #include "mlir/IR/TypeUtilities.h" |
41 | 43 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 44 | +#include "mlir/Interfaces/IndexingMapOpInterface.h" |
42 | 45 | #include "mlir/Interfaces/TilingInterface.h" |
43 | 46 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
44 | 47 |
|
@@ -1067,6 +1070,219 @@ class LoweringConfigGenerator { |
1067 | 1070 | SmallVector<bool> vectorScalableFlags; |
1068 | 1071 | }; |
1069 | 1072 |
|
| 1073 | +/// A helper class that tracks dimension mappings both within individual |
| 1074 | +/// operations and across multiple operations by analyzing the producer-consumer |
| 1075 | +/// relationships of SSA values. This tracking is established by assigning a |
| 1076 | +/// global dimension index to all loop dimensions encountered. Dimensions |
| 1077 | +/// sharing the same global index are considered equivalent. |
| 1078 | +class IterationDimTracker { |
| 1079 | +public: |
| 1080 | + explicit IterationDimTracker(ArrayRef<Operation *> operations) |
| 1081 | + : operations(operations.begin(), operations.end()) { |
| 1082 | + // Ensure operations are processed in topological order. |
| 1083 | + mlir::computeTopologicalSorting(this->operations); |
| 1084 | + buildDimMapping(); |
| 1085 | + } |
| 1086 | + |
| 1087 | + /// Returns true if the given dimension of `op` is common across all |
| 1088 | + /// operations. |
| 1089 | + bool isCommonDim(Operation *op, unsigned pos) { |
| 1090 | + assert(operationToGlobalDimMaps.contains(op)); |
| 1091 | + int64_t dim = operationToGlobalDimMaps[op][pos]; |
| 1092 | + for ([[maybe_unused]] auto &[_, dims] : operationToGlobalDimMaps) { |
| 1093 | + if (!llvm::is_contained(dims, dim)) { |
| 1094 | + return false; |
| 1095 | + } |
| 1096 | + } |
| 1097 | + return true; |
| 1098 | + } |
| 1099 | + |
| 1100 | +private: |
| 1101 | + /// Builds and unifies dimension index mappings for all operations, |
| 1102 | + /// using producer–consumer SSA value relationships. |
| 1103 | + void buildDimMapping() { |
| 1104 | + // Tracks equivalent global dimension indices. |
| 1105 | + llvm::EquivalenceClasses<int64_t> indicesEquivalence; |
| 1106 | + // For each SSA value, maps its local dimension index to a global index. |
| 1107 | + // Value -> (local dim index -> global dim index) |
| 1108 | + llvm::SmallDenseMap<Value, SmallVector<int64_t>> valueToGlobalDimMaps; |
| 1109 | + |
| 1110 | + for (Operation *op : operations) { |
| 1111 | + auto tilingOp = cast<TilingInterface>(op); |
| 1112 | + int64_t numLoops = tilingOp.getLoopIteratorTypes().size(); |
| 1113 | + // Unconditionally assign new global indices, to be unified later. |
| 1114 | + for (int64_t i = 0; i < numLoops; ++i) { |
| 1115 | + int64_t globalIndex = totalLoopNum++; |
| 1116 | + indicesEquivalence.insert(globalIndex); |
| 1117 | + operationToGlobalDimMaps[op].push_back(globalIndex); |
| 1118 | + } |
| 1119 | + // The assigned global dimension indices are now unified based on |
| 1120 | + // producer–consumer SSA value relationships: |
| 1121 | + // - For operations implementing `IndexingMapOpInterface`, unify |
| 1122 | + // dimensions by iterating over their indexing maps. |
| 1123 | + // - For pack/unpack operations, use an identity mapping, since tiling |
| 1124 | + // applies to the outer (unpacked) dimensions. |
| 1125 | + // - For all other (unknown) operations, assume an identity mapping for |
| 1126 | + // any value whose rank matches the operation’s loop count. |
| 1127 | + TypeSwitch<Operation *>(op) |
| 1128 | + .Case<IndexingMapOpInterface>([&](auto op) { |
| 1129 | + propagateOnIndexingMapOp(op, indicesEquivalence, |
| 1130 | + valueToGlobalDimMaps); |
| 1131 | + }) |
| 1132 | + .Case<linalg::PackOp, linalg::UnPackOp>([&](auto op) { |
| 1133 | + propagateOnPackUnpackOp(op, indicesEquivalence, |
| 1134 | + valueToGlobalDimMaps, numLoops); |
| 1135 | + }) |
| 1136 | + .Default([&](auto op) { |
| 1137 | + propagateOnUnknownOp(op, indicesEquivalence, valueToGlobalDimMaps, |
| 1138 | + numLoops); |
| 1139 | + }); |
| 1140 | + } |
| 1141 | + |
| 1142 | + // Remap the global dimension indices in two steps: |
| 1143 | + // 1. Assign the same temporary index to all equivalent dimensions. |
| 1144 | + // 2. Convert these temporary indices to a compact, zero-based range. |
| 1145 | + auto applyReplaceMap = [&](llvm::SmallDenseMap<int64_t, int64_t> &map) { |
| 1146 | + for (auto &opEntry : operationToGlobalDimMaps) { |
| 1147 | + for (auto &dim : opEntry.second) { |
| 1148 | + dim = map.lookup(dim); |
| 1149 | + } |
| 1150 | + } |
| 1151 | + }; |
| 1152 | + llvm::SmallDenseMap<int64_t, int64_t> replaceMap0, replaceMap1; |
| 1153 | + int64_t tempDimIndex = totalLoopNum; |
| 1154 | + totalLoopNum = 0; |
| 1155 | + for (auto it = indicesEquivalence.begin(); it != indicesEquivalence.end(); |
| 1156 | + ++it) { |
| 1157 | + if (!(*it)->isLeader()) { |
| 1158 | + continue; |
| 1159 | + } |
| 1160 | + for (auto mit = indicesEquivalence.member_begin(**it); |
| 1161 | + mit != indicesEquivalence.member_end(); ++mit) { |
| 1162 | + replaceMap0[*mit] = tempDimIndex; |
| 1163 | + } |
| 1164 | + replaceMap1[tempDimIndex] = totalLoopNum; |
| 1165 | + tempDimIndex++; |
| 1166 | + totalLoopNum++; |
| 1167 | + } |
| 1168 | + applyReplaceMap(replaceMap0); |
| 1169 | + applyReplaceMap(replaceMap1); |
| 1170 | + } |
| 1171 | + |
| 1172 | + /// Ties loop dimensions together based on the operation’s indexing maps, |
| 1173 | + /// considering only simple result dimension expressions (`AffineDimExpr`). |
| 1174 | + /// |
| 1175 | + /// Complex expressions (e.g., `affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, |
| 1176 | + /// d1 * 3 + d3)>`) are ignored because they fall outside the "loop dimension" |
| 1177 | + /// concept. Such expressions describe how indices are computed within the |
| 1178 | + /// innermost loop body, but they do not directly identify which loop |
| 1179 | + /// dimensions correspond or should be tied. |
| 1180 | + void propagateOnIndexingMapOp( |
| 1181 | + IndexingMapOpInterface indexingMapOp, |
| 1182 | + llvm::EquivalenceClasses<int64_t> &indicesEquivalence, |
| 1183 | + llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps) { |
| 1184 | + Operation *op = indexingMapOp.getOperation(); |
| 1185 | + for (OpOperand &operand : op->getOpOperands()) { |
| 1186 | + Value value = operand.get(); |
| 1187 | + // Skip operands that have no known mapping from their producers. |
| 1188 | + if (!valueToGlobalDimMaps.contains(value)) { |
| 1189 | + continue; |
| 1190 | + } |
| 1191 | + AffineMap map = indexingMapOp.getMatchingIndexingMap(&operand); |
| 1192 | + for (auto [dim, expr] : llvm::enumerate(map.getResults())) { |
| 1193 | + // Stop if the current dimension exceeds the number of mapped ones. |
| 1194 | + if (dim >= valueToGlobalDimMaps[value].size()) { |
| 1195 | + break; |
| 1196 | + } |
| 1197 | + // Skip on complex expressions. |
| 1198 | + auto dimExpr = dyn_cast<AffineDimExpr>(expr); |
| 1199 | + if (!dimExpr) { |
| 1200 | + continue; |
| 1201 | + } |
| 1202 | + int64_t pos = dimExpr.getPosition(); |
| 1203 | + // Unify the dimension index between the producer and the current op. |
| 1204 | + indicesEquivalence.unionSets(valueToGlobalDimMaps[value][dim], |
| 1205 | + operationToGlobalDimMaps[op][pos]); |
| 1206 | + } |
| 1207 | + } |
| 1208 | + // Propogate to results. |
| 1209 | + auto dsOp = cast<DestinationStyleOpInterface>(op); |
| 1210 | + for (OpResult result : op->getResults()) { |
| 1211 | + OpOperand *operand = dsOp.getTiedOpOperand(result); |
| 1212 | + AffineMap map = indexingMapOp.getMatchingIndexingMap(operand); |
| 1213 | + for (auto [dim, expr] : llvm::enumerate(map.getResults())) { |
| 1214 | + // Skip on complex expressions. |
| 1215 | + auto dimExpr = dyn_cast<AffineDimExpr>(expr); |
| 1216 | + if (!dimExpr) { |
| 1217 | + continue; |
| 1218 | + } |
| 1219 | + int64_t pos = dimExpr.getPosition(); |
| 1220 | + valueToGlobalDimMaps[result].push_back( |
| 1221 | + operationToGlobalDimMaps[op][pos]); |
| 1222 | + } |
| 1223 | + } |
| 1224 | + } |
| 1225 | + |
| 1226 | + /// Ties the dimensions of pack and unpack operations with their operands in |
| 1227 | + /// the outer (unpacked) dimensions. |
| 1228 | + void propagateOnPackUnpackOp( |
| 1229 | + Operation *op, llvm::EquivalenceClasses<int64_t> &indicesEquivalence, |
| 1230 | + llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps, |
| 1231 | + int64_t numLoops) { |
| 1232 | + for (OpOperand &operand : op->getOpOperands()) { |
| 1233 | + Value value = operand.get(); |
| 1234 | + if (!valueToGlobalDimMaps.contains(value)) { |
| 1235 | + continue; |
| 1236 | + } |
| 1237 | + int64_t rank = cast<ShapedType>(value.getType()).getRank(); |
| 1238 | + int64_t outDimSize = std::min(rank, numLoops); |
| 1239 | + for (int64_t i = 0; i < outDimSize; ++i) { |
| 1240 | + indicesEquivalence.unionSets(valueToGlobalDimMaps[value][i], |
| 1241 | + operationToGlobalDimMaps[op][i]); |
| 1242 | + } |
| 1243 | + } |
| 1244 | + // Propagate to results. |
| 1245 | + for (Value result : op->getResults()) { |
| 1246 | + valueToGlobalDimMaps[result] = operationToGlobalDimMaps[op]; |
| 1247 | + } |
| 1248 | + } |
| 1249 | + |
| 1250 | + /// Ties the dimensions of operations with their operands, if the operand rank |
| 1251 | + /// matches the operation’s loop count. |
| 1252 | + void propagateOnUnknownOp( |
| 1253 | + Operation *op, llvm::EquivalenceClasses<int64_t> &indicesEquivalence, |
| 1254 | + llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps, |
| 1255 | + int64_t numLoops) { |
| 1256 | + for (OpOperand &operand : op->getOpOperands()) { |
| 1257 | + Value value = operand.get(); |
| 1258 | + if (!valueToGlobalDimMaps.contains(value) || |
| 1259 | + numLoops != cast<ShapedType>(value.getType()).getRank()) { |
| 1260 | + continue; |
| 1261 | + } |
| 1262 | + for (int64_t i = 0; i < numLoops; ++i) { |
| 1263 | + indicesEquivalence.unionSets(valueToGlobalDimMaps[value][i], |
| 1264 | + operationToGlobalDimMaps[op][i]); |
| 1265 | + } |
| 1266 | + } |
| 1267 | + // Propagate to results. |
| 1268 | + for (Value result : op->getResults()) { |
| 1269 | + if (numLoops == cast<ShapedType>(result.getType()).getRank()) { |
| 1270 | + valueToGlobalDimMaps[result] = operationToGlobalDimMaps[op]; |
| 1271 | + } |
| 1272 | + } |
| 1273 | + } |
| 1274 | + |
| 1275 | + SmallVector<Operation *> operations; |
| 1276 | + // Tracks the total number of unique loop dimensions among the given set of |
| 1277 | + // operations. |
| 1278 | + int64_t totalLoopNum = 0; |
| 1279 | + // For each compute operation, maps its local loop dimension index to the |
| 1280 | + // global index. Operation -> (local dim index -> global dim |
| 1281 | + // index) |
| 1282 | + llvm::SmallDenseMap<Operation *, SmallVector<int64_t>> |
| 1283 | + operationToGlobalDimMaps; |
| 1284 | +}; |
| 1285 | + |
1070 | 1286 | /// Returns the same lowering_config attribute with the updated tile sizes and |
1071 | 1287 | /// scalable tile flags. The distribution tiling sizes is not set if it is |
1072 | 1288 | /// false. |
@@ -3054,10 +3270,12 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn, |
3054 | 3270 | SmallVector<bool> commonVecScalableTileFlags = parallelVecScalableTileSizes; |
3055 | 3271 | SmallVector<int64_t> innerVecTileSizes(maxLoopNums, 0); |
3056 | 3272 | SmallVector<bool> innerVecScalableTileFlags(maxLoopNums, false); |
| 3273 | + IterationDimTracker dimTracker(computeOps); |
3057 | 3274 | for (auto op : computeOps) { |
3058 | 3275 | auto iterTypes = cast<TilingInterface>(op).getLoopIteratorTypes(); |
3059 | 3276 | for (auto [idx, iterType] : llvm::enumerate(iterTypes)) { |
3060 | | - if (iterType == utils::IteratorType::reduction) { |
| 3277 | + if (iterType == utils::IteratorType::reduction || |
| 3278 | + !dimTracker.isCommonDim(op, idx)) { |
3061 | 3279 | innerVecTileSizes[idx] = parallelVecTileSizes[idx]; |
3062 | 3280 | innerVecScalableTileFlags[idx] = parallelVecScalableTileSizes[idx]; |
3063 | 3281 | commonVecTileSizes[idx] = 0; |
|
0 commit comments