Skip to content

Commit 08efffa

Browse files
authored
[LLVMCPU] Tracks the dimension mapping for multi lowering config (#21649)
This adds `IterationDimTracker` to determine dimension mappings both within individual operations and across multiple operations. The tracker assigns a global dimension index to all loop dimensions encountered (where “local” refers to an individual operation and “global” refers to all target operations). By analyzing producer-consumer relationships of SSA values, dimensions that are considered equivalent are assigned the same global dimension index. It is currently used to improve lowering configuration propagation by identifying loop dimensions that are common across all target operations. --------- Signed-off-by: Yu-Zhewen <[email protected]>
1 parent 8b272a3 commit 08efffa

File tree

4 files changed

+270
-1
lines changed

4 files changed

+270
-1
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ iree_compiler_cc_library(
145145
"@llvm-project//mlir:FunctionInterfaces",
146146
"@llvm-project//mlir:IR",
147147
"@llvm-project//mlir:IndexToLLVM",
148+
"@llvm-project//mlir:IndexingMapOpInterface",
148149
"@llvm-project//mlir:LLVMCommonConversion",
149150
"@llvm-project//mlir:LLVMDialect",
150151
"@llvm-project//mlir:LinalgDialect",

compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ iree_cc_library(
114114
MLIRFunctionInterfaces
115115
MLIRIR
116116
MLIRIndexToLLVM
117+
MLIRIndexingMapOpInterface
117118
MLIRLLVMCommonConversion
118119
MLIRLLVMDialect
119120
MLIRLinalgDialect

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 219 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
2020
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
2121
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
22+
#include "llvm/ADT/EquivalenceClasses.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/SmallVectorExtras.h"
2425
#include "llvm/ADT/TypeSwitch.h"
2526
#include "llvm/Support/CommandLine.h"
2627
#include "llvm/Support/DebugLog.h"
2728
#include "llvm/Support/InterleavedRange.h"
2829
#include "llvm/Support/MathExtras.h"
30+
#include "mlir/Analysis/TopologicalSortUtils.h"
2931
#include "mlir/Dialect/Linalg/IR/Linalg.h"
3032
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
3133
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -39,6 +41,7 @@
3941
#include "mlir/IR/OpDefinition.h"
4042
#include "mlir/IR/TypeUtilities.h"
4143
#include "mlir/Interfaces/FunctionInterfaces.h"
44+
#include "mlir/Interfaces/IndexingMapOpInterface.h"
4245
#include "mlir/Interfaces/TilingInterface.h"
4346
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4447

@@ -1067,6 +1070,219 @@ class LoweringConfigGenerator {
10671070
SmallVector<bool> vectorScalableFlags;
10681071
};
10691072

1073+
/// A helper class that tracks dimension mappings both within individual
1074+
/// operations and across multiple operations by analyzing the producer-consumer
1075+
/// relationships of SSA values. This tracking is established by assigning a
1076+
/// global dimension index to all loop dimensions encountered. Dimensions
1077+
/// sharing the same global index are considered equivalent.
1078+
class IterationDimTracker {
1079+
public:
1080+
explicit IterationDimTracker(ArrayRef<Operation *> operations)
1081+
: operations(operations.begin(), operations.end()) {
1082+
// Ensure operations are processed in topological order.
1083+
mlir::computeTopologicalSorting(this->operations);
1084+
buildDimMapping();
1085+
}
1086+
1087+
/// Returns true if the given dimension of `op` is common across all
1088+
/// operations.
1089+
bool isCommonDim(Operation *op, unsigned pos) {
1090+
assert(operationToGlobalDimMaps.contains(op));
1091+
int64_t dim = operationToGlobalDimMaps[op][pos];
1092+
for ([[maybe_unused]] auto &[_, dims] : operationToGlobalDimMaps) {
1093+
if (!llvm::is_contained(dims, dim)) {
1094+
return false;
1095+
}
1096+
}
1097+
return true;
1098+
}
1099+
1100+
private:
1101+
/// Builds and unifies dimension index mappings for all operations,
1102+
/// using producer–consumer SSA value relationships.
1103+
void buildDimMapping() {
1104+
// Tracks equivalent global dimension indices.
1105+
llvm::EquivalenceClasses<int64_t> indicesEquivalence;
1106+
// For each SSA value, maps its local dimension index to a global index.
1107+
// Value -> (local dim index -> global dim index)
1108+
llvm::SmallDenseMap<Value, SmallVector<int64_t>> valueToGlobalDimMaps;
1109+
1110+
for (Operation *op : operations) {
1111+
auto tilingOp = cast<TilingInterface>(op);
1112+
int64_t numLoops = tilingOp.getLoopIteratorTypes().size();
1113+
// Unconditionally assign new global indices, to be unified later.
1114+
for (int64_t i = 0; i < numLoops; ++i) {
1115+
int64_t globalIndex = totalLoopNum++;
1116+
indicesEquivalence.insert(globalIndex);
1117+
operationToGlobalDimMaps[op].push_back(globalIndex);
1118+
}
1119+
// The assigned global dimension indices are now unified based on
1120+
// producer–consumer SSA value relationships:
1121+
// - For operations implementing `IndexingMapOpInterface`, unify
1122+
// dimensions by iterating over their indexing maps.
1123+
// - For pack/unpack operations, use an identity mapping, since tiling
1124+
// applies to the outer (unpacked) dimensions.
1125+
// - For all other (unknown) operations, assume an identity mapping for
1126+
// any value whose rank matches the operation’s loop count.
1127+
TypeSwitch<Operation *>(op)
1128+
.Case<IndexingMapOpInterface>([&](auto op) {
1129+
propagateOnIndexingMapOp(op, indicesEquivalence,
1130+
valueToGlobalDimMaps);
1131+
})
1132+
.Case<linalg::PackOp, linalg::UnPackOp>([&](auto op) {
1133+
propagateOnPackUnpackOp(op, indicesEquivalence,
1134+
valueToGlobalDimMaps, numLoops);
1135+
})
1136+
.Default([&](auto op) {
1137+
propagateOnUnknownOp(op, indicesEquivalence, valueToGlobalDimMaps,
1138+
numLoops);
1139+
});
1140+
}
1141+
1142+
// Remap the global dimension indices in two steps:
1143+
// 1. Assign the same temporary index to all equivalent dimensions.
1144+
// 2. Convert these temporary indices to a compact, zero-based range.
1145+
auto applyReplaceMap = [&](llvm::SmallDenseMap<int64_t, int64_t> &map) {
1146+
for (auto &opEntry : operationToGlobalDimMaps) {
1147+
for (auto &dim : opEntry.second) {
1148+
dim = map.lookup(dim);
1149+
}
1150+
}
1151+
};
1152+
llvm::SmallDenseMap<int64_t, int64_t> replaceMap0, replaceMap1;
1153+
int64_t tempDimIndex = totalLoopNum;
1154+
totalLoopNum = 0;
1155+
for (auto it = indicesEquivalence.begin(); it != indicesEquivalence.end();
1156+
++it) {
1157+
if (!(*it)->isLeader()) {
1158+
continue;
1159+
}
1160+
for (auto mit = indicesEquivalence.member_begin(**it);
1161+
mit != indicesEquivalence.member_end(); ++mit) {
1162+
replaceMap0[*mit] = tempDimIndex;
1163+
}
1164+
replaceMap1[tempDimIndex] = totalLoopNum;
1165+
tempDimIndex++;
1166+
totalLoopNum++;
1167+
}
1168+
applyReplaceMap(replaceMap0);
1169+
applyReplaceMap(replaceMap1);
1170+
}
1171+
1172+
/// Ties loop dimensions together based on the operation’s indexing maps,
1173+
/// considering only simple result dimension expressions (`AffineDimExpr`).
1174+
///
1175+
/// Complex expressions (e.g., `affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2,
1176+
/// d1 * 3 + d3)>`) are ignored because they fall outside the "loop dimension"
1177+
/// concept. Such expressions describe how indices are computed within the
1178+
/// innermost loop body, but they do not directly identify which loop
1179+
/// dimensions correspond or should be tied.
1180+
void propagateOnIndexingMapOp(
1181+
IndexingMapOpInterface indexingMapOp,
1182+
llvm::EquivalenceClasses<int64_t> &indicesEquivalence,
1183+
llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps) {
1184+
Operation *op = indexingMapOp.getOperation();
1185+
for (OpOperand &operand : op->getOpOperands()) {
1186+
Value value = operand.get();
1187+
// Skip operands that have no known mapping from their producers.
1188+
if (!valueToGlobalDimMaps.contains(value)) {
1189+
continue;
1190+
}
1191+
AffineMap map = indexingMapOp.getMatchingIndexingMap(&operand);
1192+
for (auto [dim, expr] : llvm::enumerate(map.getResults())) {
1193+
// Stop if the current dimension exceeds the number of mapped ones.
1194+
if (dim >= valueToGlobalDimMaps[value].size()) {
1195+
break;
1196+
}
1197+
// Skip on complex expressions.
1198+
auto dimExpr = dyn_cast<AffineDimExpr>(expr);
1199+
if (!dimExpr) {
1200+
continue;
1201+
}
1202+
int64_t pos = dimExpr.getPosition();
1203+
// Unify the dimension index between the producer and the current op.
1204+
indicesEquivalence.unionSets(valueToGlobalDimMaps[value][dim],
1205+
operationToGlobalDimMaps[op][pos]);
1206+
}
1207+
}
1208+
// Propogate to results.
1209+
auto dsOp = cast<DestinationStyleOpInterface>(op);
1210+
for (OpResult result : op->getResults()) {
1211+
OpOperand *operand = dsOp.getTiedOpOperand(result);
1212+
AffineMap map = indexingMapOp.getMatchingIndexingMap(operand);
1213+
for (auto [dim, expr] : llvm::enumerate(map.getResults())) {
1214+
// Skip on complex expressions.
1215+
auto dimExpr = dyn_cast<AffineDimExpr>(expr);
1216+
if (!dimExpr) {
1217+
continue;
1218+
}
1219+
int64_t pos = dimExpr.getPosition();
1220+
valueToGlobalDimMaps[result].push_back(
1221+
operationToGlobalDimMaps[op][pos]);
1222+
}
1223+
}
1224+
}
1225+
1226+
/// Ties the dimensions of pack and unpack operations with their operands in
1227+
/// the outer (unpacked) dimensions.
1228+
void propagateOnPackUnpackOp(
1229+
Operation *op, llvm::EquivalenceClasses<int64_t> &indicesEquivalence,
1230+
llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps,
1231+
int64_t numLoops) {
1232+
for (OpOperand &operand : op->getOpOperands()) {
1233+
Value value = operand.get();
1234+
if (!valueToGlobalDimMaps.contains(value)) {
1235+
continue;
1236+
}
1237+
int64_t rank = cast<ShapedType>(value.getType()).getRank();
1238+
int64_t outDimSize = std::min(rank, numLoops);
1239+
for (int64_t i = 0; i < outDimSize; ++i) {
1240+
indicesEquivalence.unionSets(valueToGlobalDimMaps[value][i],
1241+
operationToGlobalDimMaps[op][i]);
1242+
}
1243+
}
1244+
// Propagate to results.
1245+
for (Value result : op->getResults()) {
1246+
valueToGlobalDimMaps[result] = operationToGlobalDimMaps[op];
1247+
}
1248+
}
1249+
1250+
/// Ties the dimensions of operations with their operands, if the operand rank
1251+
/// matches the operation’s loop count.
1252+
void propagateOnUnknownOp(
1253+
Operation *op, llvm::EquivalenceClasses<int64_t> &indicesEquivalence,
1254+
llvm::SmallDenseMap<Value, SmallVector<int64_t>> &valueToGlobalDimMaps,
1255+
int64_t numLoops) {
1256+
for (OpOperand &operand : op->getOpOperands()) {
1257+
Value value = operand.get();
1258+
if (!valueToGlobalDimMaps.contains(value) ||
1259+
numLoops != cast<ShapedType>(value.getType()).getRank()) {
1260+
continue;
1261+
}
1262+
for (int64_t i = 0; i < numLoops; ++i) {
1263+
indicesEquivalence.unionSets(valueToGlobalDimMaps[value][i],
1264+
operationToGlobalDimMaps[op][i]);
1265+
}
1266+
}
1267+
// Propagate to results.
1268+
for (Value result : op->getResults()) {
1269+
if (numLoops == cast<ShapedType>(result.getType()).getRank()) {
1270+
valueToGlobalDimMaps[result] = operationToGlobalDimMaps[op];
1271+
}
1272+
}
1273+
}
1274+
1275+
SmallVector<Operation *> operations;
1276+
// Tracks the total number of unique loop dimensions among the given set of
1277+
// operations.
1278+
int64_t totalLoopNum = 0;
1279+
// For each compute operation, maps its local loop dimension index to the
1280+
// global index. Operation -> (local dim index -> global dim
1281+
// index)
1282+
llvm::SmallDenseMap<Operation *, SmallVector<int64_t>>
1283+
operationToGlobalDimMaps;
1284+
};
1285+
10701286
/// Returns the same lowering_config attribute with the updated tile sizes and
10711287
/// scalable tile flags. The distribution tiling sizes is not set if it is
10721288
/// false.
@@ -3054,10 +3270,12 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
30543270
SmallVector<bool> commonVecScalableTileFlags = parallelVecScalableTileSizes;
30553271
SmallVector<int64_t> innerVecTileSizes(maxLoopNums, 0);
30563272
SmallVector<bool> innerVecScalableTileFlags(maxLoopNums, false);
3273+
IterationDimTracker dimTracker(computeOps);
30573274
for (auto op : computeOps) {
30583275
auto iterTypes = cast<TilingInterface>(op).getLoopIteratorTypes();
30593276
for (auto [idx, iterType] : llvm::enumerate(iterTypes)) {
3060-
if (iterType == utils::IteratorType::reduction) {
3277+
if (iterType == utils::IteratorType::reduction ||
3278+
!dimTracker.isCommonDim(op, idx)) {
30613279
innerVecTileSizes[idx] = parallelVecTileSizes[idx];
30623280
innerVecScalableTileFlags[idx] = parallelVecScalableTileSizes[idx];
30633281
commonVecTileSizes[idx] = 0;

compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,3 +2068,52 @@ func.func @complex_view_as_real() attributes {hal.executable.target = #executabl
20682068
// CHECK: func.func @complex_view_as_real()
20692069
// CHECK: linalg.generic
20702070
// CHECK-SAME: lowering_config = #[[CONFIG]]
2071+
2072+
// -----
2073+
2074+
// Tests lowering-config propagation with IterationDimTracker: the two
// linalg.generic ops share only d0 (the size-32 dimension), so only that
// dimension may receive a common parallel tile size; the second generic's
// remaining parallel dims fall back to vector_inner_parallel.
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "+avx512f", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-linux-gnu"}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
func.func @decode_reduction_f32(%arg0: tensor<32x262144xf16>, %arg1: tensor<32xf32>, %arg2: tensor<32x16x16384xf16>, %arg3: tensor<32x16xf16>, %arg4: tensor<32x16xf16>) -> tensor<16384x32x16xf16> attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 2.621440e+05 : f32
  %cst_1 = arith.constant 9.99999997E-7 : f32
  %0 = tensor.empty() : tensor<16384x32x16xf16>
  %1 = tensor.empty() : tensor<32xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
  // Reduction over d1 (262144); produces the per-row statistic consumed below.
  %3 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<32x262144xf16>, tensor<32xf32>) outs(%2 : tensor<32xf32>) {
  ^bb0(%in: f16, %in_2: f32, %out: f32):
    %5 = arith.extf %in : f16 to f32
    %6 = arith.subf %5, %in_2 : f32
    %7 = arith.mulf %6, %6 : f32
    %8 = arith.addf %7, %out : f32
    linalg.yield %8 : f32
  } -> tensor<32xf32>
  // All-parallel consumer; shares only d0 with the reduction above via %3.
  %4 = linalg.generic {indexing_maps = [#map2, #map3, #map3, #map4, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2, %arg1, %3, %arg3, %arg4 : tensor<32x16x16384xf16>, tensor<32xf32>, tensor<32xf32>, tensor<32x16xf16>, tensor<32x16xf16>) outs(%0 : tensor<16384x32x16xf16>) {
  ^bb0(%in: f16, %in_2: f32, %in_3: f32, %in_4: f16, %in_5: f16, %out: f16):
    %5 = arith.divf %in_3, %cst_0 : f32
    %6 = arith.addf %5, %cst_1 : f32
    %7 = math.rsqrt %6 : f32
    %8 = arith.extf %in : f16 to f32
    %9 = arith.subf %8, %in_2 : f32
    %10 = arith.mulf %9, %7 : f32
    %11 = arith.extf %in_4 : f16 to f32
    %12 = arith.mulf %10, %11 : f32
    %13 = arith.extf %in_5 : f16 to f32
    %14 = arith.addf %12, %13 : f32
    %15 = arith.truncf %14 : f32 to f16
    linalg.yield %15 : f16
  } -> tensor<16384x32x16xf16>
  return %4 : tensor<16384x32x16xf16>
}
// CHECK-DAG: #[[CONFIG0:.+]] = #iree_cpu.lowering_config<distribution = [4, 0], vector_common_parallel = [4, 0], vector_reduction = [0, 8]>
// CHECK-DAG: #[[CONFIG1:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [4, 0, 0], vector_inner_parallel = [0, 1, 4]>
// CHECK: func.func @decode_reduction_f32
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[CONFIG0]]
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[CONFIG1]]

0 commit comments

Comments
 (0)