@@ -116,7 +116,6 @@ static llvm::cl::opt<bool> clEnableRiscvAggressiveDist(
116116 llvm::cl::init(false ));
117117
118118using IREE::Codegen::DispatchLoweringPassPipeline;
119- using IREE::CPU::TilingLevel;
120119
121120// Encodes the pre-processing strategy to be applied on a Linalg operation
122121// before vectorization.
@@ -1003,21 +1002,23 @@ class LoweringConfigGenerator {
10031002 IREE::CPU::LoweringConfigAttr generateCPULoweringConfig () {
10041003 SmallVector<NamedAttribute> items;
10051004 if (!distTileSizes.empty ()) {
1006- appendLoweringConfigLevelAttr (items, TilingLevel::DistributionTiles,
1007- distTileSizes);
1005+ appendLoweringConfigLevelAttr (
1006+ items, IREE::CPU::TilingLevel::DistributionTiles, distTileSizes);
10081007 } else if (auto op = dyn_cast<TilingInterface>(rootOp)) {
10091008 size_t numTilingDims = op.getLoopIteratorTypes ().size ();
1010- appendLoweringConfigLevelAttr (items, TilingLevel::DistributionTiles,
1009+ appendLoweringConfigLevelAttr (items,
1010+ IREE::CPU::TilingLevel::DistributionTiles,
10111011 SmallVector<int64_t >(numTilingDims, 0 ));
10121012 }
10131013 if (!cacheTileSizes.empty ()) {
10141014 SmallVector<int64_t > parallelTileSizes = cacheTileSizes;
10151015 SmallVector<int64_t > reductionTileSizes;
10161016 splitParallelAndReductionTiles (rootOp, parallelTileSizes,
10171017 reductionTileSizes);
1018- appendLoweringConfigLevelAttr (items, TilingLevel::CacheParallelTiles,
1019- parallelTileSizes);
1020- appendLoweringConfigLevelAttr (items, TilingLevel::CacheReductionTiles,
1018+ appendLoweringConfigLevelAttr (
1019+ items, IREE::CPU::TilingLevel::CacheParallelTiles, parallelTileSizes);
1020+ appendLoweringConfigLevelAttr (items,
1021+ IREE::CPU::TilingLevel::CacheReductionTiles,
10211022 reductionTileSizes);
10221023 }
10231024 if (!vectorTileSizes.empty ()) {
@@ -1029,11 +1030,12 @@ class LoweringConfigGenerator {
10291030 splitParallelAndReductionTiles (rootOp, parallelTileSizes,
10301031 reductionTileSizes, ¶llelScalableFlags,
10311032 &reductionScalableFlags);
1032- appendLoweringConfigLevelAttr (items,
1033- TilingLevel::VectorCommonParallelTiles,
1034- parallelTileSizes, parallelScalableFlags);
1035- appendLoweringConfigLevelAttr (items, TilingLevel::VectorReductionTiles,
1036- reductionTileSizes, reductionScalableFlags);
1033+ appendLoweringConfigLevelAttr (
1034+ items, IREE::CPU::TilingLevel::VectorCommonParallelTiles,
1035+ parallelTileSizes, parallelScalableFlags);
1036+ appendLoweringConfigLevelAttr (
1037+ items, IREE::CPU::TilingLevel::VectorReductionTiles,
1038+ reductionTileSizes, reductionScalableFlags);
10371039 }
10381040 return IREE::CPU::LoweringConfigAttr::get (ctx, items);
10391041 }
@@ -1044,10 +1046,10 @@ class LoweringConfigGenerator {
10441046 // / it means no tiling at all. Only the distribution tiling level is
10451047 // / unconditionally added because a root op expects the level to be present.
10461048 void appendLoweringConfigLevelAttr (SmallVectorImpl<NamedAttribute> &items,
1047- TilingLevel level,
1049+ IREE::CPU:: TilingLevel level,
10481050 ArrayRef<int64_t > tileSizes,
10491051 ArrayRef<bool > scalableFlags = {}) {
1050- if (level != TilingLevel::DistributionTiles &&
1052+ if (level != IREE::CPU:: TilingLevel::DistributionTiles &&
10511053 llvm::all_of (tileSizes, [](int64_t v) { return v == 0 ; })) {
10521054 return ;
10531055 }
@@ -1309,13 +1311,14 @@ getNewLoweringConfig(MLIRContext *ctx,
13091311 bool setDistributionConfig) {
13101312 SmallVector<NamedAttribute> newItems;
13111313 for (auto [level, tileSizes, scalableFlags] : tilingInfo) {
1312- if (!setDistributionConfig && level == TilingLevel::DistributionTiles) {
1314+ if (!setDistributionConfig &&
1315+ level == IREE::CPU::TilingLevel::DistributionTiles) {
13131316 continue ;
13141317 }
13151318 // Distribution tile sizes is a must for rootOp, because it is the
13161319 // definition of root op. An operation that has distribution tile sizes is
13171320 // the root op. Other level can be dropped if all the tile sizes are zeros.
1318- if (level != TilingLevel::DistributionTiles &&
1321+ if (level != IREE::CPU:: TilingLevel::DistributionTiles &&
13191322 llvm::all_of (tileSizes, [](int64_t val) { return val == 0 ; })) {
13201323 continue ;
13211324 }
@@ -3096,8 +3099,10 @@ class MultiLoweringConfigGenerator {
30963099 IterationDimTracker dimTracker;
30973100 // For each tiling level, store per-dimension tiling information.
30983101 // TilingLevel -> (global loop dimension index -> tile size / scalable flag)
3099- llvm::SmallDenseMap<TilingLevel, SmallVector<int64_t >> globalTileSizes;
3100- llvm::SmallDenseMap<TilingLevel, SmallVector<bool >> globalScalableTileFlags;
3102+ llvm::SmallDenseMap<IREE::CPU::TilingLevel, SmallVector<int64_t >>
3103+ globalTileSizes;
3104+ llvm::SmallDenseMap<IREE::CPU::TilingLevel, SmallVector<bool >>
3105+ globalScalableTileFlags;
31013106
31023107 // Store the vector parallel tile sizes preferred by non-root operations.
31033108 // Operation -> (global loop dimension index -> tile size)
@@ -3134,26 +3139,26 @@ MultiLoweringConfigGenerator::create(Operation *rootOperation,
31343139void MultiLoweringConfigGenerator::loadRootLoweringConfig () {
31353140 const int64_t totalLoopNum = dimTracker.getTotalLoopNum ();
31363141
3137- auto loadTilingLevel = [&](TilingLevel level) {
3142+ auto loadTilingLevel = [&](IREE::CPU:: TilingLevel level) {
31383143 SmallVector<int64_t > sizes;
31393144 SmallVector<bool > flags;
3140- if (level == TilingLevel::DistributionTiles) {
3145+ if (level == IREE::CPU:: TilingLevel::DistributionTiles) {
31413146 assert (rootLoweringConfig.hasWorkgroupTilingLevel () &&
31423147 " Expected root lowering config to have workgroup tiling level." );
31433148 sizes = rootLoweringConfig.getWorkgroupTileSizes ();
31443149 flags.resize (sizes.size (), false );
3145- } else if (level == TilingLevel::VectorCommonParallelTiles) {
3146- if (rootLoweringConfig.hasTilingLevel (level)) {
3150+ } else if (level == IREE::CPU:: TilingLevel::VectorCommonParallelTiles) {
3151+ if (rootLoweringConfig.hasTilingLevel (llvm::to_underlying ( level) )) {
31473152 auto attr = llvm::cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
3148- rootLoweringConfig.getTilingLevelAttr (level));
3153+ rootLoweringConfig.getTilingLevelAttr (llvm::to_underlying ( level) ));
31493154 sizes.assign (attr.getSizes ());
31503155 // Only `VectorCommonParallel` has scalable flags.
31513156 flags.assign (attr.getScalableFlags ());
31523157 }
31533158 } else {
3154- if (rootLoweringConfig.hasTilingLevel (level)) {
3155- sizes =
3156- rootLoweringConfig. getStaticTilingLevelSizes (level, rootOperation);
3159+ if (rootLoweringConfig.hasTilingLevel (llvm::to_underlying ( level) )) {
3160+ sizes = rootLoweringConfig. getStaticTilingLevelSizes (
3161+ llvm::to_underlying (level) , rootOperation);
31573162 flags.resize (sizes.size (), false );
31583163 }
31593164 }
@@ -3180,9 +3185,8 @@ void MultiLoweringConfigGenerator::loadRootLoweringConfig() {
31803185 };
31813186
31823187 // Load all tiling levels.
3183- for (int i = 0 , e = static_cast <int >(TilingLevel::MaxNumTileLevels); i < e;
3184- ++i) {
3185- loadTilingLevel (static_cast <TilingLevel>(i));
3188+ for (auto i : IREE::CPU::getTilingLevelsAsInts ()) {
3189+ loadTilingLevel (static_cast <IREE::CPU::TilingLevel>(i));
31863190 }
31873191}
31883192
@@ -3211,7 +3215,7 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32113215 ArrayRef<int64_t > rootOpGlobalDims =
32123216 dimTracker.getAllGlobalDimIdx (rootOperation);
32133217 auto adjust = [&](Operation *op, ArrayRef<int64_t > vecTileSize,
3214- TilingLevel level,
3218+ IREE::CPU:: TilingLevel level,
32153219 llvm::function_ref<int64_t (int64_t , int64_t )> updater) {
32163220 for (auto [pos, size] : llvm::enumerate (vecTileSize)) {
32173221 int64_t globalDimIdx = dimTracker.getGlobalDimIdx (op, pos);
@@ -3231,8 +3235,8 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32313235 if (isa<linalg::PackOp>(op)) {
32323236 // For pack op, align the distribution tile size and overwrite the
32333237 // vector parallel tile size.
3234- adjust (op, vecTileSize, TilingLevel::DistributionTiles, align);
3235- adjust (op, vecTileSize, TilingLevel::VectorCommonParallelTiles,
3238+ adjust (op, vecTileSize, IREE::CPU:: TilingLevel::DistributionTiles, align);
3239+ adjust (op, vecTileSize, IREE::CPU:: TilingLevel::VectorCommonParallelTiles,
32363240 overwrite);
32373241 } else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
32383242 // For unpack op, just overwrite the vector parallel tile size.
@@ -3261,7 +3265,7 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32613265 adjustedTileSize[dimExpr.getPosition ()] = tileSize;
32623266 }
32633267 adjust (linalgOp.getOperation (), adjustedTileSize,
3264- TilingLevel::VectorCommonParallelTiles, overwrite);
3268+ IREE::CPU:: TilingLevel::VectorCommonParallelTiles, overwrite);
32653269 }
32663270 }
32673271
@@ -3285,8 +3289,8 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32853289 if (elementTypeSize == 1 ) {
32863290 SmallVector<int64_t > vecTileSize (rootOpGlobalDims.size (), 0 );
32873291 vecTileSize.back () = 8 ;
3288- adjust (rootOperation, vecTileSize, TilingLevel::VectorCommonParallelTiles,
3289- align);
3292+ adjust (rootOperation, vecTileSize,
3293+ IREE::CPU::TilingLevel::VectorCommonParallelTiles, align);
32903294 }
32913295 }
32923296}
@@ -3306,7 +3310,8 @@ void MultiLoweringConfigGenerator::fillTileSizesWithNonRootOps() {
33063310 for (auto [pos, size] : llvm::enumerate (vecTileSize)) {
33073311 int64_t globalDimIdx = dimTracker.getGlobalDimIdx (op, pos);
33083312 int64_t &tile =
3309- globalTileSizes[TilingLevel::VectorCommonParallelTiles][globalDimIdx];
3313+ globalTileSizes[IREE::CPU::TilingLevel::VectorCommonParallelTiles]
3314+ [globalDimIdx];
33103315 // Only set the tile size if it hasn't been assigned yet.
33113316 if (tile == 0 && size > 0 ) {
33123317 tile = size;
@@ -3329,10 +3334,12 @@ void MultiLoweringConfigGenerator::getGenericReductionTileSizes() {
33293334 continue ;
33303335 }
33313336 int64_t globalDimIdx = dimTracker.getGlobalDimIdx (op, pos);
3332- globalTileSizes[TilingLevel::VectorReductionTiles][globalDimIdx] = size;
3333- globalScalableTileFlags[TilingLevel::VectorReductionTiles][globalDimIdx] =
3334- globalScalableTileFlags[TilingLevel::VectorCommonParallelTiles]
3335- [globalDimIdx];
3337+ globalTileSizes[IREE::CPU::TilingLevel::VectorReductionTiles]
3338+ [globalDimIdx] = size;
3339+ globalScalableTileFlags
3340+ [IREE::CPU::TilingLevel::VectorReductionTiles]
3341+ [globalDimIdx] = globalScalableTileFlags
3342+ [IREE::CPU::TilingLevel::VectorCommonParallelTiles][globalDimIdx];
33363343 }
33373344 }
33383345}
@@ -3343,23 +3350,24 @@ void MultiLoweringConfigGenerator::splitCommonInnerVectorTiles() {
33433350 const int64_t totalLoopNum = dimTracker.getTotalLoopNum ();
33443351
33453352 // Initialize inner parallel tiles.
3346- globalTileSizes[TilingLevel::VectorInnerParallelTiles].assign (totalLoopNum,
3347- 0 );
3348- globalScalableTileFlags[TilingLevel::VectorInnerParallelTiles]. assign (
3349- totalLoopNum, false );
3353+ globalTileSizes[IREE::CPU:: TilingLevel::VectorInnerParallelTiles].assign (
3354+ totalLoopNum, 0 );
3355+ globalScalableTileFlags[IREE::CPU:: TilingLevel::VectorInnerParallelTiles]
3356+ . assign ( totalLoopNum, false );
33503357
33513358 auto isReductionDim = [&](int64_t globalDimIdx) {
3352- return globalTileSizes[TilingLevel::VectorReductionTiles][globalDimIdx] > 0 ;
3359+ return globalTileSizes[IREE::CPU::TilingLevel::VectorReductionTiles]
3360+ [globalDimIdx] > 0 ;
33533361 };
33543362
33553363 SmallVector<int64_t > &commonSizes =
3356- globalTileSizes[TilingLevel::VectorCommonParallelTiles];
3357- SmallVector<bool > &commonFlags =
3358- globalScalableTileFlags[ TilingLevel::VectorCommonParallelTiles];
3364+ globalTileSizes[IREE::CPU:: TilingLevel::VectorCommonParallelTiles];
3365+ SmallVector<bool > &commonFlags = globalScalableTileFlags
3366+ [IREE::CPU:: TilingLevel::VectorCommonParallelTiles];
33593367 SmallVector<int64_t > &innerSizes =
3360- globalTileSizes[TilingLevel::VectorInnerParallelTiles];
3368+ globalTileSizes[IREE::CPU:: TilingLevel::VectorInnerParallelTiles];
33613369 SmallVector<bool > &innerFlags =
3362- globalScalableTileFlags[TilingLevel::VectorInnerParallelTiles];
3370+ globalScalableTileFlags[IREE::CPU:: TilingLevel::VectorInnerParallelTiles];
33633371 for (auto [globalDimIdx, size, flag] :
33643372 llvm::enumerate (commonSizes, commonFlags)) {
33653373 // "Common" means a parallel loop present either in all compute ops or in
@@ -3377,7 +3385,7 @@ void MultiLoweringConfigGenerator::splitCommonInnerVectorTiles() {
33773385}
33783386
33793387void MultiLoweringConfigGenerator::setNewTilingConfigs () {
3380- SmallVector<TilingLevel> tilingLevels;
3388+ SmallVector<IREE::CPU:: TilingLevel> tilingLevels;
33813389 tilingLevels.reserve (globalTileSizes.size ());
33823390 for (const auto &entry : globalTileSizes) {
33833391 tilingLevels.push_back (entry.first );
@@ -3390,7 +3398,7 @@ void MultiLoweringConfigGenerator::setNewTilingConfigs() {
33903398 int numLoops = iterTypes.size ();
33913399 SmallVector<IREE::CPU::LoweringConfigLevelInfo> newTilingInfo;
33923400 // Collect new tiling info.
3393- for (TilingLevel level : tilingLevels) {
3401+ for (IREE::CPU:: TilingLevel level : tilingLevels) {
33943402 SmallVector<int64_t > tileSizes (numLoops, 0 );
33953403 SmallVector<bool > scalableFlags (numLoops, false );
33963404 for (auto [pos, iterType] : llvm::enumerate (iterTypes)) {
@@ -3402,7 +3410,7 @@ void MultiLoweringConfigGenerator::setNewTilingConfigs() {
34023410 // - If the loop dimension is not a reduction but the current tiling
34033411 // level is `VectorReductionTiles`, skip it.
34043412 if ((iterType == utils::IteratorType::reduction) ^
3405- (level == TilingLevel::VectorReductionTiles)) {
3413+ (level == IREE::CPU:: TilingLevel::VectorReductionTiles)) {
34063414 continue ;
34073415 }
34083416 tileSizes[pos] = globalTileSizes[level][globalDimIdx];
0 commit comments