20
20
#include " iree/compiler/Dialect/HAL/IR/HALTypes.h"
21
21
#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
22
22
#include " iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
23
+ #include " llvm/ADT/STLExtras.h"
23
24
#include " llvm/ADT/SmallVectorExtras.h"
24
25
#include " llvm/ADT/TypeSwitch.h"
25
26
#include " llvm/Support/CommandLine.h"
@@ -110,6 +111,7 @@ static llvm::cl::opt<bool> clEnableRiscvAggressiveDist(
110
111
llvm::cl::init(false ));
111
112
112
113
using IREE::Codegen::DispatchLoweringPassPipeline;
114
+ using IREE::CPU::TilingLevel;
113
115
114
116
// Encodes the pre-processing strategy to be applied on a Linalg operation
115
117
// before vectorization.
@@ -960,14 +962,14 @@ static void setAlwaysVectorizeSizes(linalg::LinalgOp op,
960
962
}
961
963
962
964
// / A helper class to record different level tiling sizes and generate
963
- // / corresponding IREE::CPU::LoweringConfigAttr. Only vector level supports
964
- // / scalable tile sizes for now.
965
+ // / corresponding IREE::CPU::LoweringConfigAttr for the rootOp. It implies that
966
+ // / the distribution tiling level is always set, even if tile sizes are all
967
+ // / zeros, because a rootOp must have a distribution tiling level.
968
+ // / Only the vector level supports scalable tile sizes for now.
965
969
class LoweringConfigGenerator {
966
970
public:
967
- explicit LoweringConfigGenerator (Operation *op,
968
- bool emitInnerParallelList = false )
969
- : ctx(op->getContext ()), rootOp(op),
970
- emitInnerParallelList(emitInnerParallelList) {}
971
+ explicit LoweringConfigGenerator (Operation *op)
972
+ : ctx(op->getContext ()), rootOp(op) {}
971
973
972
974
void setDistributionTileSizes (ArrayRef<int64_t > tileSizes) {
973
975
assert (distTileSizes.empty () && " expected to set only once" );
@@ -990,7 +992,6 @@ class LoweringConfigGenerator {
990
992
// / existing values. By default, it will always contain distribution tile
991
993
// / sizes, unless the rootOp does not implement TilingInterface.
992
994
IREE::CPU::LoweringConfigAttr generateCPULoweringConfig () {
993
- using TilingLevel = IREE::CPU::TilingLevel;
994
995
SmallVector<NamedAttribute> items;
995
996
if (!distTileSizes.empty ()) {
996
997
appendLoweringConfigLevelAttr (items, TilingLevel::DistributionTiles,
@@ -1024,21 +1025,23 @@ class LoweringConfigGenerator {
1024
1025
parallelTileSizes, parallelScalableFlags);
1025
1026
appendLoweringConfigLevelAttr (items, TilingLevel::VectorReductionTiles,
1026
1027
reductionTileSizes, reductionScalableFlags);
1027
- if (emitInnerParallelList) {
1028
- size_t size = parallelTileSizes.size ();
1029
- appendLoweringConfigLevelAttr (
1030
- items, TilingLevel::VectorInnerParallelTiles,
1031
- SmallVector<int64_t >(size, 0 ), SmallVector<bool >(size, false ));
1032
- }
1033
1028
}
1034
1029
return IREE::CPU::LoweringConfigAttr::get (ctx, items);
1035
1030
}
1036
1031
1037
1032
private:
1033
+ // / Appends the `level` with (`tileSizes`, `scalableFlags`) tiling config to
1034
+ // / `items`, if it is not a NOP config. E.g., if all the tile sizes are zeros,
1035
+ // / it means no tiling at all. Only the distribution tiling level is
1036
+ // / unconditionally added because a root op expects the level to be present.
1038
1037
void appendLoweringConfigLevelAttr (SmallVectorImpl<NamedAttribute> &items,
1039
- IREE::CPU:: TilingLevel level,
1038
+ TilingLevel level,
1040
1039
ArrayRef<int64_t > tileSizes,
1041
1040
ArrayRef<bool > scalableFlags = {}) {
1041
+ if (level != TilingLevel::DistributionTiles &&
1042
+ llvm::all_of (tileSizes, [](int64_t v) { return v == 0 ; })) {
1043
+ return ;
1044
+ }
1042
1045
items.emplace_back (IREE::CPU::getTilingLevelName (level),
1043
1046
IREE::CPU::LoweringConfigAttr::getTilingLevelAttr (
1044
1047
ctx, tileSizes, scalableFlags));
@@ -1047,12 +1050,6 @@ class LoweringConfigGenerator {
1047
1050
MLIRContext *ctx;
1048
1051
Operation *rootOp;
1049
1052
1050
- // Generates the `IREE::CPU::TilingLevel::VectorInnerParallelTiles` tile sizes
1051
- // in the lowering config. Usually, they are zero values.
1052
- // TODO(hanchung): Remove the field once all the pipelines switch to CPU
1053
- // lowering_config. It is alive for legacy setup.
1054
- bool emitInnerParallelList = false ;
1055
-
1056
1053
// The tile sizes for distribution from the `rootOp`'s perspective.
1057
1054
SmallVector<int64_t > distTileSizes;
1058
1055
@@ -1092,8 +1089,14 @@ static IREE::Codegen::LoweringConfigAttrInterface getNewLoweringConfig(
1092
1089
1093
1090
SmallVector<NamedAttribute> newItems;
1094
1091
for (auto [level, tileSizes, scalableFlags] : tilingInfo) {
1095
- if (!setDistributionConfig &&
1096
- level == IREE::CPU::TilingLevel::DistributionTiles) {
1092
+ if (!setDistributionConfig && level == TilingLevel::DistributionTiles) {
1093
+ continue ;
1094
+ }
1095
+ // Distribution tile sizes are a must for rootOp, because it is the
1096
+ // definition of root op. An operation that has distribution tile sizes is
1097
+ // the root op. Other levels can be dropped if all the tile sizes are zeros.
1098
+ if (level != TilingLevel::DistributionTiles &&
1099
+ llvm::all_of (tileSizes, [](int64_t val) { return val == 0 ; })) {
1097
1100
continue ;
1098
1101
}
1099
1102
newItems.emplace_back (IREE::CPU::getTilingLevelName (level),
@@ -1155,7 +1158,7 @@ static LogicalResult setMatmulPeelingRootConfig(
1155
1158
inputVecScalableTileFlags.end ());
1156
1159
vectorScalableFlags.back () = false ;
1157
1160
1158
- LoweringConfigGenerator generator (op, /* emitInnerParallelList= */ true );
1161
+ LoweringConfigGenerator generator (op);
1159
1162
generator.setDistributionTileSizes (distTileSizes);
1160
1163
generator.setCacheTileSizes (cacheTileSizes);
1161
1164
generator.setVectorTileSizes (vecTileSizes, vectorScalableFlags);
@@ -1206,7 +1209,7 @@ static LogicalResult setMatmulRootConfig(
1206
1209
}
1207
1210
limitVectorTileSizes (cast<linalg::LinalgOp>(op.getOperation ()), vecTileSizes);
1208
1211
1209
- LoweringConfigGenerator generator (op, /* emitInnerParallelList= */ true );
1212
+ LoweringConfigGenerator generator (op);
1210
1213
generator.setDistributionTileSizes (distTileSizes);
1211
1214
generator.setVectorTileSizes (vecTileSizes, vecScalableFlags);
1212
1215
IREE::CPU::LoweringConfigAttr loweringConfig =
@@ -2085,10 +2088,9 @@ setDefaultGenericOpRootConfig(mlir::FunctionOpInterface entryPointFn,
2085
2088
// If there are no loops, there is nothing to do.
2086
2089
unsigned numLoops = genericOp.getNumLoops ();
2087
2090
if (numLoops == 0 ) {
2091
+ LoweringConfigGenerator generator (genericOp);
2088
2092
return setOpConfigAndEntryPointFnTranslation (
2089
- entryPointFn, genericOp,
2090
- IREE::CPU::LoweringConfigAttr::get (genericOp.getContext (),
2091
- SmallVector<NamedAttribute>()),
2093
+ entryPointFn, genericOp, generator.generateCPULoweringConfig (),
2092
2094
DispatchLoweringPassPipeline::CPUDefault);
2093
2095
}
2094
2096
@@ -2113,7 +2115,7 @@ setDefaultGenericOpRootConfig(mlir::FunctionOpInterface entryPointFn,
2113
2115
distConfig.maxTileSizes , vecPreProcStrategy, vecTileSizes);
2114
2116
limitVectorTileSizes (genericOp, vecTileSizes);
2115
2117
2116
- LoweringConfigGenerator generator (genericOp, /* emitInnerParallelList= */ true );
2118
+ LoweringConfigGenerator generator (genericOp);
2117
2119
generator.setDistributionTileSizes (distTileSizes);
2118
2120
generator.setVectorTileSizes (vecTileSizes);
2119
2121
IREE::CPU::LoweringConfigAttr loweringConfig =
@@ -2267,7 +2269,7 @@ setTransposeLikeOpRootConfig(mlir::FunctionOpInterface entryPointFn,
2267
2269
SmallVector<int64_t > distTileSizes =
2268
2270
getDefaultDistributedLevelTileSizes (genericOp, distConfig);
2269
2271
2270
- LoweringConfigGenerator generator (genericOp, /* emitInnerParallelList= */ true );
2272
+ LoweringConfigGenerator generator (genericOp);
2271
2273
generator.setDistributionTileSizes (distTileSizes);
2272
2274
generator.setVectorTileSizes (vecSizes, vecScalableDims);
2273
2275
IREE::CPU::LoweringConfigAttr loweringConfig =
@@ -2346,7 +2348,7 @@ static LogicalResult setElementwiseGenericOpRootConfig(
2346
2348
vecPreProcStrategy == VectorPreProcStrategy::Masking);
2347
2349
}
2348
2350
2349
- LoweringConfigGenerator generator (genericOp, /* emitInnerParallelList= */ true );
2351
+ LoweringConfigGenerator generator (genericOp);
2350
2352
generator.setDistributionTileSizes (distTileSizes);
2351
2353
generator.setVectorTileSizes (vecTileSizes);
2352
2354
IREE::CPU::LoweringConfigAttr loweringConfig =
@@ -2603,7 +2605,7 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
2603
2605
2604
2606
SmallVector<int64_t > distTileSizes =
2605
2607
getDefaultDistributedLevelTileSizes (padOp, distConfig);
2606
- LoweringConfigGenerator generator (padOp, /* emitInnerParallelList= */ true );
2608
+ LoweringConfigGenerator generator (padOp);
2607
2609
generator.setDistributionTileSizes (distTileSizes);
2608
2610
generator.setVectorTileSizes (distConfig.vectorSizeHints );
2609
2611
IREE::CPU::LoweringConfigAttr loweringConfig =
@@ -2879,7 +2881,7 @@ adjustTileSizesForGenericOp(mlir::FunctionOpInterface entryPointFn,
2879
2881
// / `level`, if it is present. Otherwise, adds a new item to the vector.
2880
2882
static void updateOrAddTilingLevelInfo (
2881
2883
SmallVectorImpl<IREE::CPU::LoweringConfigLevelInfo> &tilingInfo,
2882
- IREE::CPU:: TilingLevel level, ArrayRef<int64_t > tileSizes,
2884
+ TilingLevel level, ArrayRef<int64_t > tileSizes,
2883
2885
ArrayRef<bool > scalableFlags) {
2884
2886
for (IREE::CPU::LoweringConfigLevelInfo &info : tilingInfo) {
2885
2887
if (info.level == level) {
@@ -2946,16 +2948,18 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
2946
2948
}
2947
2949
2948
2950
auto rootLoweringConfig = getLoweringConfig (rootOperation);
2949
- std::unique_ptr<TilingConfig> tilingConfig =
2950
- TilingConfig::create (rootLoweringConfig);
2951
2951
SmallVector<int64_t > distTileSizes, parallelVecTileSizes;
2952
2952
SmallVector<bool > distScalableTileSizes, parallelVecScalableTileSizes;
2953
- if (tilingConfig->getNumTilingLevels () > 0 ) {
2954
- distTileSizes = tilingConfig->getDistributionTileSizes ();
2955
- }
2956
- if (tilingConfig->getNumTilingLevels () > 1 ) {
2957
- std::tie (parallelVecTileSizes, parallelVecScalableTileSizes) =
2958
- tilingConfig->getVectorCommonParallelSizes ();
2953
+ assert (rootLoweringConfig.hasWorkgroupTilingLevel ());
2954
+ distTileSizes = rootLoweringConfig.getWorkgroupTileSizes ();
2955
+ if (rootLoweringConfig.hasTilingLevel (
2956
+ TilingLevel::VectorCommonParallelTiles)) {
2957
+ auto attr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
2958
+ rootLoweringConfig.getTilingLevelAttr (
2959
+ TilingLevel::VectorCommonParallelTiles));
2960
+ parallelVecTileSizes.assign (attr.getSizes ().begin (), attr.getSizes ().end ());
2961
+ parallelVecScalableTileSizes.assign (attr.getScalableFlags ().begin (),
2962
+ attr.getScalableFlags ().end ());
2959
2963
}
2960
2964
2961
2965
size_t maxLoopNums = 0 ;
@@ -3080,6 +3084,9 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
3080
3084
}
3081
3085
3082
3086
// Set the lowering configs with new tile sizes.
3087
+ // TODO(hanchung): Deprecate TilingConfig from the file.
3088
+ std::unique_ptr<TilingConfig> tilingConfig =
3089
+ TilingConfig::create (rootLoweringConfig);
3083
3090
for (auto op : computeOps) {
3084
3091
int numLoops = cast<TilingInterface>(op).getLoopIteratorTypes ().size ();
3085
3092
SmallVector<IREE::CPU::LoweringConfigLevelInfo> newTilingInfo;
0 commit comments