Skip to content

Commit 731e5ca

Browse files
CopilothanhanW
andauthored
[CPU] Switch IREE::CPU::TilingLevel to enum class (#22433)
Converts `IREE::CPU::TilingLevel` from a plain enum to `enum class` to eliminate namespace pollution and enforce type safety per [Abseil Tip #86](https://abseil.io/tips/86). ## What Changed This PR refactors the `TilingLevel` enum in the IREE CPU codegen to use `enum class` instead of a plain enum, which: - Eliminates namespace pollution by requiring qualified access (`IREE::CPU::TilingLevel::DistributionTiles` instead of `IREE::CPU::DistributionTiles`) - Enforces type safety by preventing implicit conversions - Uses default `int` type instead of `unsigned` following [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html#Integer_Types) **Pass option behavior change**: The `LLVMCPUTilePass` now accepts string-based enum values (e.g., `tiling-level=distribution`) instead of integer values (e.g., `tiling-level=0`), and defaults to `InvalidLevel` instead of `-1`. ## Key Changes - **IREECPUTypes.h**: Changed `enum TilingLevel` to `enum class TilingLevel` with default `int` type - **Passes.td**: Updated `LLVMCPUTilePass` to use `IREE::CPU::TilingLevel` type with string-based enum values (following the pattern from `LLVMCPUTileAndFuseProducerConsumerPass`) - **Passes.h**: Updated `createLLVMCPUTilePass` signature to use `IREE::CPU::TilingLevel` instead of `int64_t` - **All usage sites**: Updated to use fully qualified enum values (`IREE::CPU::TilingLevel::*`) and added `llvm::to_underlying()` where integer conversion is needed - **Tests**: Updated `tile.mlir` to use string values (`distribution`, `vector_common_parallel`) instead of integers - **Function signatures**: Updated `tileRootAndFuseProducerConsumer` to accept `IREE::CPU::TilingLevel` parameter Fixes #22432 --------- Signed-off-by: hanhanW <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: hanhanW <[email protected]> Co-authored-by: hanhanW <[email protected]>
1 parent 166cda3 commit 731e5ca

13 files changed

+133
-103
lines changed

compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,29 @@ constexpr StringLiteral kVectorCommonParallelConfigKey =
2929
constexpr StringLiteral kVectorReductionConfigKey = "vector_reduction";
3030
constexpr StringLiteral kVectorInnerParallelConfigKey = "vector_inner_parallel";
3131

32+
SmallVector<int> getTilingLevelsAsInts() {
33+
return llvm::to_vector(
34+
llvm::seq<int>(0, llvm::to_underlying(TilingLevel::MaxNumTileLevels)));
35+
}
36+
3237
/// Returns the entry key for the config in IREE::CPU::LoweringConfigAttr.
3338
/// Returns null if `level` is invalid.
3439
StringRef getTilingLevelName(TilingLevel level) {
3540
switch (level) {
36-
case DistributionTiles:
41+
case TilingLevel::DistributionTiles:
3742
return kDistributionConfigKey;
38-
case CacheParallelTiles:
43+
case TilingLevel::CacheParallelTiles:
3944
return kCacheParallelConfigKey;
40-
case CacheReductionTiles:
45+
case TilingLevel::CacheReductionTiles:
4146
return kCacheReductionConfigKey;
42-
case VectorCommonParallelTiles:
47+
case TilingLevel::VectorCommonParallelTiles:
4348
return kVectorCommonParallelConfigKey;
44-
case VectorReductionTiles:
49+
case TilingLevel::VectorReductionTiles:
4550
return kVectorReductionConfigKey;
46-
case VectorInnerParallelTiles:
51+
case TilingLevel::VectorInnerParallelTiles:
4752
return kVectorInnerParallelConfigKey;
48-
case MaxNumTileLevels:
49-
case InvalidLevel:
53+
case TilingLevel::MaxNumTileLevels:
54+
case TilingLevel::InvalidLevel:
5055
default:
5156
return StringRef();
5257
}
@@ -161,7 +166,7 @@ Attribute LoweringConfigAttr::getTilingLevelAttr(MLIRContext *ctx,
161166
SmallVector<LoweringConfigLevelInfo>
162167
LoweringConfigAttr::getAvailableTilingInfo() {
163168
SmallVector<LoweringConfigLevelInfo> result;
164-
for (unsigned i = 0, e = TilingLevel::MaxNumTileLevels; i < e; ++i) {
169+
for (auto i : IREE::CPU::getTilingLevelsAsInts()) {
165170
if (!hasTilingLevel(i)) {
166171
continue;
167172
}
@@ -177,7 +182,7 @@ LoweringConfigAttr::getAvailableTilingInfo() {
177182
}
178183

179184
SmallVector<int64_t> LoweringConfigAttr::getWorkgroupTileSizes() const {
180-
return getTileSizes(getConfig(), DistributionTiles);
185+
return getTileSizes(getConfig(), TilingLevel::DistributionTiles);
181186
}
182187

183188
SmallVector<OpFoldResult>
@@ -232,11 +237,11 @@ constexpr std::array vectorTilingLevels{TilingLevel::VectorCommonParallelTiles,
232237
std::optional<SmallVector<int64_t>> LoweringConfigAttr::getVectorSizes() const {
233238
SmallVector<int64_t> result;
234239
for (auto level : vectorTilingLevels) {
235-
if (!hasTilingLevel(level)) {
240+
if (!hasTilingLevel(static_cast<unsigned>(level))) {
236241
continue;
237242
}
238243
auto attr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
239-
getTilingLevelAttr(level));
244+
getTilingLevelAttr(static_cast<unsigned>(level)));
240245
if (result.empty()) {
241246
result.resize(attr.getSizes().size(), 0);
242247
}
@@ -256,11 +261,11 @@ std::optional<SmallVector<int64_t>> LoweringConfigAttr::getVectorSizes() const {
256261
SmallVector<bool> LoweringConfigAttr::getVectorScalableFlags() const {
257262
SmallVector<bool> result;
258263
for (auto level : vectorTilingLevels) {
259-
if (!hasTilingLevel(level)) {
264+
if (!hasTilingLevel(static_cast<unsigned>(level))) {
260265
continue;
261266
}
262267
auto attr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
263-
getTilingLevelAttr(level));
268+
getTilingLevelAttr(static_cast<unsigned>(level)));
264269
ArrayRef<bool> scalableFlags = attr.getScalableFlags();
265270
if (result.empty() && !scalableFlags.empty()) {
266271
result.resize(attr.getSizes().size(), false);

compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace mlir::iree_compiler::IREE::CPU {
1515

1616
/// Representation for all the supported tiling levels. All or just a subset of
1717
/// them may be available in a valid configuration.
18-
enum TilingLevel : unsigned {
18+
enum class TilingLevel {
1919
DistributionTiles = 0,
2020
CacheParallelTiles = 1,
2121
CacheReductionTiles = 2,
@@ -32,6 +32,9 @@ struct LoweringConfigLevelInfo {
3232
SmallVector<bool> scalableFlags;
3333
};
3434

35+
/// Returns all the tiling levels as integer values.
36+
SmallVector<int> getTilingLevelsAsInts();
37+
3538
/// Returns the corresponding key string for `level`.
3639
StringRef getTilingLevelName(TilingLevel level);
3740

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ static llvm::cl::opt<bool> clEnableRiscvAggressiveDist(
116116
llvm::cl::init(false));
117117

118118
using IREE::Codegen::DispatchLoweringPassPipeline;
119-
using IREE::CPU::TilingLevel;
120119

121120
// Encodes the pre-processing strategy to be applied on a Linalg operation
122121
// before vectorization.
@@ -1003,21 +1002,23 @@ class LoweringConfigGenerator {
10031002
IREE::CPU::LoweringConfigAttr generateCPULoweringConfig() {
10041003
SmallVector<NamedAttribute> items;
10051004
if (!distTileSizes.empty()) {
1006-
appendLoweringConfigLevelAttr(items, TilingLevel::DistributionTiles,
1007-
distTileSizes);
1005+
appendLoweringConfigLevelAttr(
1006+
items, IREE::CPU::TilingLevel::DistributionTiles, distTileSizes);
10081007
} else if (auto op = dyn_cast<TilingInterface>(rootOp)) {
10091008
size_t numTilingDims = op.getLoopIteratorTypes().size();
1010-
appendLoweringConfigLevelAttr(items, TilingLevel::DistributionTiles,
1009+
appendLoweringConfigLevelAttr(items,
1010+
IREE::CPU::TilingLevel::DistributionTiles,
10111011
SmallVector<int64_t>(numTilingDims, 0));
10121012
}
10131013
if (!cacheTileSizes.empty()) {
10141014
SmallVector<int64_t> parallelTileSizes = cacheTileSizes;
10151015
SmallVector<int64_t> reductionTileSizes;
10161016
splitParallelAndReductionTiles(rootOp, parallelTileSizes,
10171017
reductionTileSizes);
1018-
appendLoweringConfigLevelAttr(items, TilingLevel::CacheParallelTiles,
1019-
parallelTileSizes);
1020-
appendLoweringConfigLevelAttr(items, TilingLevel::CacheReductionTiles,
1018+
appendLoweringConfigLevelAttr(
1019+
items, IREE::CPU::TilingLevel::CacheParallelTiles, parallelTileSizes);
1020+
appendLoweringConfigLevelAttr(items,
1021+
IREE::CPU::TilingLevel::CacheReductionTiles,
10211022
reductionTileSizes);
10221023
}
10231024
if (!vectorTileSizes.empty()) {
@@ -1029,11 +1030,12 @@ class LoweringConfigGenerator {
10291030
splitParallelAndReductionTiles(rootOp, parallelTileSizes,
10301031
reductionTileSizes, &parallelScalableFlags,
10311032
&reductionScalableFlags);
1032-
appendLoweringConfigLevelAttr(items,
1033-
TilingLevel::VectorCommonParallelTiles,
1034-
parallelTileSizes, parallelScalableFlags);
1035-
appendLoweringConfigLevelAttr(items, TilingLevel::VectorReductionTiles,
1036-
reductionTileSizes, reductionScalableFlags);
1033+
appendLoweringConfigLevelAttr(
1034+
items, IREE::CPU::TilingLevel::VectorCommonParallelTiles,
1035+
parallelTileSizes, parallelScalableFlags);
1036+
appendLoweringConfigLevelAttr(
1037+
items, IREE::CPU::TilingLevel::VectorReductionTiles,
1038+
reductionTileSizes, reductionScalableFlags);
10371039
}
10381040
return IREE::CPU::LoweringConfigAttr::get(ctx, items);
10391041
}
@@ -1044,10 +1046,10 @@ class LoweringConfigGenerator {
10441046
/// it means no tiling at all. Only the distribution tiling level is
10451047
/// unconditionally added because a root op expects the level to be present.
10461048
void appendLoweringConfigLevelAttr(SmallVectorImpl<NamedAttribute> &items,
1047-
TilingLevel level,
1049+
IREE::CPU::TilingLevel level,
10481050
ArrayRef<int64_t> tileSizes,
10491051
ArrayRef<bool> scalableFlags = {}) {
1050-
if (level != TilingLevel::DistributionTiles &&
1052+
if (level != IREE::CPU::TilingLevel::DistributionTiles &&
10511053
llvm::all_of(tileSizes, [](int64_t v) { return v == 0; })) {
10521054
return;
10531055
}
@@ -1309,13 +1311,14 @@ getNewLoweringConfig(MLIRContext *ctx,
13091311
bool setDistributionConfig) {
13101312
SmallVector<NamedAttribute> newItems;
13111313
for (auto [level, tileSizes, scalableFlags] : tilingInfo) {
1312-
if (!setDistributionConfig && level == TilingLevel::DistributionTiles) {
1314+
if (!setDistributionConfig &&
1315+
level == IREE::CPU::TilingLevel::DistributionTiles) {
13131316
continue;
13141317
}
13151318
// Distribution tile sizes is a must for rootOp, because it is the
13161319
// definition of root op. An operation that has distribution tile sizes is
13171320
// the root op. Other level can be dropped if all the tile sizes are zeros.
1318-
if (level != TilingLevel::DistributionTiles &&
1321+
if (level != IREE::CPU::TilingLevel::DistributionTiles &&
13191322
llvm::all_of(tileSizes, [](int64_t val) { return val == 0; })) {
13201323
continue;
13211324
}
@@ -3096,8 +3099,10 @@ class MultiLoweringConfigGenerator {
30963099
IterationDimTracker dimTracker;
30973100
// For each tiling level, store per-dimension tiling information.
30983101
// TilingLevel -> (global loop dimension index -> tile size / scalable flag)
3099-
llvm::SmallDenseMap<TilingLevel, SmallVector<int64_t>> globalTileSizes;
3100-
llvm::SmallDenseMap<TilingLevel, SmallVector<bool>> globalScalableTileFlags;
3102+
llvm::SmallDenseMap<IREE::CPU::TilingLevel, SmallVector<int64_t>>
3103+
globalTileSizes;
3104+
llvm::SmallDenseMap<IREE::CPU::TilingLevel, SmallVector<bool>>
3105+
globalScalableTileFlags;
31013106

31023107
// Store the vector parallel tile sizes preferred by non-root operations.
31033108
// Operation -> (global loop dimension index -> tile size)
@@ -3134,26 +3139,26 @@ MultiLoweringConfigGenerator::create(Operation *rootOperation,
31343139
void MultiLoweringConfigGenerator::loadRootLoweringConfig() {
31353140
const int64_t totalLoopNum = dimTracker.getTotalLoopNum();
31363141

3137-
auto loadTilingLevel = [&](TilingLevel level) {
3142+
auto loadTilingLevel = [&](IREE::CPU::TilingLevel level) {
31383143
SmallVector<int64_t> sizes;
31393144
SmallVector<bool> flags;
3140-
if (level == TilingLevel::DistributionTiles) {
3145+
if (level == IREE::CPU::TilingLevel::DistributionTiles) {
31413146
assert(rootLoweringConfig.hasWorkgroupTilingLevel() &&
31423147
"Expected root lowering config to have workgroup tiling level.");
31433148
sizes = rootLoweringConfig.getWorkgroupTileSizes();
31443149
flags.resize(sizes.size(), false);
3145-
} else if (level == TilingLevel::VectorCommonParallelTiles) {
3146-
if (rootLoweringConfig.hasTilingLevel(level)) {
3150+
} else if (level == IREE::CPU::TilingLevel::VectorCommonParallelTiles) {
3151+
if (rootLoweringConfig.hasTilingLevel(llvm::to_underlying(level))) {
31473152
auto attr = llvm::cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
3148-
rootLoweringConfig.getTilingLevelAttr(level));
3153+
rootLoweringConfig.getTilingLevelAttr(llvm::to_underlying(level)));
31493154
sizes.assign(attr.getSizes());
31503155
// Only `VectorCommonParallel` has scalable flags.
31513156
flags.assign(attr.getScalableFlags());
31523157
}
31533158
} else {
3154-
if (rootLoweringConfig.hasTilingLevel(level)) {
3155-
sizes =
3156-
rootLoweringConfig.getStaticTilingLevelSizes(level, rootOperation);
3159+
if (rootLoweringConfig.hasTilingLevel(llvm::to_underlying(level))) {
3160+
sizes = rootLoweringConfig.getStaticTilingLevelSizes(
3161+
llvm::to_underlying(level), rootOperation);
31573162
flags.resize(sizes.size(), false);
31583163
}
31593164
}
@@ -3180,9 +3185,8 @@ void MultiLoweringConfigGenerator::loadRootLoweringConfig() {
31803185
};
31813186

31823187
// Load all tiling levels.
3183-
for (int i = 0, e = static_cast<int>(TilingLevel::MaxNumTileLevels); i < e;
3184-
++i) {
3185-
loadTilingLevel(static_cast<TilingLevel>(i));
3188+
for (auto i : IREE::CPU::getTilingLevelsAsInts()) {
3189+
loadTilingLevel(static_cast<IREE::CPU::TilingLevel>(i));
31863190
}
31873191
}
31883192

@@ -3211,7 +3215,7 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32113215
ArrayRef<int64_t> rootOpGlobalDims =
32123216
dimTracker.getAllGlobalDimIdx(rootOperation);
32133217
auto adjust = [&](Operation *op, ArrayRef<int64_t> vecTileSize,
3214-
TilingLevel level,
3218+
IREE::CPU::TilingLevel level,
32153219
llvm::function_ref<int64_t(int64_t, int64_t)> updater) {
32163220
for (auto [pos, size] : llvm::enumerate(vecTileSize)) {
32173221
int64_t globalDimIdx = dimTracker.getGlobalDimIdx(op, pos);
@@ -3231,8 +3235,8 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32313235
if (isa<linalg::PackOp>(op)) {
32323236
// For pack op, align the distribution tile size and overwrite the
32333237
// vector parallel tile size.
3234-
adjust(op, vecTileSize, TilingLevel::DistributionTiles, align);
3235-
adjust(op, vecTileSize, TilingLevel::VectorCommonParallelTiles,
3238+
adjust(op, vecTileSize, IREE::CPU::TilingLevel::DistributionTiles, align);
3239+
adjust(op, vecTileSize, IREE::CPU::TilingLevel::VectorCommonParallelTiles,
32363240
overwrite);
32373241
} else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
32383242
// For unpack op, just overwrite the vector parallel tile size.
@@ -3261,7 +3265,7 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32613265
adjustedTileSize[dimExpr.getPosition()] = tileSize;
32623266
}
32633267
adjust(linalgOp.getOperation(), adjustedTileSize,
3264-
TilingLevel::VectorCommonParallelTiles, overwrite);
3268+
IREE::CPU::TilingLevel::VectorCommonParallelTiles, overwrite);
32653269
}
32663270
}
32673271

@@ -3285,8 +3289,8 @@ void MultiLoweringConfigGenerator::adjustTileSizesForRootOp() {
32853289
if (elementTypeSize == 1) {
32863290
SmallVector<int64_t> vecTileSize(rootOpGlobalDims.size(), 0);
32873291
vecTileSize.back() = 8;
3288-
adjust(rootOperation, vecTileSize, TilingLevel::VectorCommonParallelTiles,
3289-
align);
3292+
adjust(rootOperation, vecTileSize,
3293+
IREE::CPU::TilingLevel::VectorCommonParallelTiles, align);
32903294
}
32913295
}
32923296
}
@@ -3306,7 +3310,8 @@ void MultiLoweringConfigGenerator::fillTileSizesWithNonRootOps() {
33063310
for (auto [pos, size] : llvm::enumerate(vecTileSize)) {
33073311
int64_t globalDimIdx = dimTracker.getGlobalDimIdx(op, pos);
33083312
int64_t &tile =
3309-
globalTileSizes[TilingLevel::VectorCommonParallelTiles][globalDimIdx];
3313+
globalTileSizes[IREE::CPU::TilingLevel::VectorCommonParallelTiles]
3314+
[globalDimIdx];
33103315
// Only set the tile size if it hasn't been assigned yet.
33113316
if (tile == 0 && size > 0) {
33123317
tile = size;
@@ -3329,10 +3334,12 @@ void MultiLoweringConfigGenerator::getGenericReductionTileSizes() {
33293334
continue;
33303335
}
33313336
int64_t globalDimIdx = dimTracker.getGlobalDimIdx(op, pos);
3332-
globalTileSizes[TilingLevel::VectorReductionTiles][globalDimIdx] = size;
3333-
globalScalableTileFlags[TilingLevel::VectorReductionTiles][globalDimIdx] =
3334-
globalScalableTileFlags[TilingLevel::VectorCommonParallelTiles]
3335-
[globalDimIdx];
3337+
globalTileSizes[IREE::CPU::TilingLevel::VectorReductionTiles]
3338+
[globalDimIdx] = size;
3339+
globalScalableTileFlags
3340+
[IREE::CPU::TilingLevel::VectorReductionTiles]
3341+
[globalDimIdx] = globalScalableTileFlags
3342+
[IREE::CPU::TilingLevel::VectorCommonParallelTiles][globalDimIdx];
33363343
}
33373344
}
33383345
}
@@ -3343,23 +3350,24 @@ void MultiLoweringConfigGenerator::splitCommonInnerVectorTiles() {
33433350
const int64_t totalLoopNum = dimTracker.getTotalLoopNum();
33443351

33453352
// Initialize inner parallel tiles.
3346-
globalTileSizes[TilingLevel::VectorInnerParallelTiles].assign(totalLoopNum,
3347-
0);
3348-
globalScalableTileFlags[TilingLevel::VectorInnerParallelTiles].assign(
3349-
totalLoopNum, false);
3353+
globalTileSizes[IREE::CPU::TilingLevel::VectorInnerParallelTiles].assign(
3354+
totalLoopNum, 0);
3355+
globalScalableTileFlags[IREE::CPU::TilingLevel::VectorInnerParallelTiles]
3356+
.assign(totalLoopNum, false);
33503357

33513358
auto isReductionDim = [&](int64_t globalDimIdx) {
3352-
return globalTileSizes[TilingLevel::VectorReductionTiles][globalDimIdx] > 0;
3359+
return globalTileSizes[IREE::CPU::TilingLevel::VectorReductionTiles]
3360+
[globalDimIdx] > 0;
33533361
};
33543362

33553363
SmallVector<int64_t> &commonSizes =
3356-
globalTileSizes[TilingLevel::VectorCommonParallelTiles];
3357-
SmallVector<bool> &commonFlags =
3358-
globalScalableTileFlags[TilingLevel::VectorCommonParallelTiles];
3364+
globalTileSizes[IREE::CPU::TilingLevel::VectorCommonParallelTiles];
3365+
SmallVector<bool> &commonFlags = globalScalableTileFlags
3366+
[IREE::CPU::TilingLevel::VectorCommonParallelTiles];
33593367
SmallVector<int64_t> &innerSizes =
3360-
globalTileSizes[TilingLevel::VectorInnerParallelTiles];
3368+
globalTileSizes[IREE::CPU::TilingLevel::VectorInnerParallelTiles];
33613369
SmallVector<bool> &innerFlags =
3362-
globalScalableTileFlags[TilingLevel::VectorInnerParallelTiles];
3370+
globalScalableTileFlags[IREE::CPU::TilingLevel::VectorInnerParallelTiles];
33633371
for (auto [globalDimIdx, size, flag] :
33643372
llvm::enumerate(commonSizes, commonFlags)) {
33653373
// "Common" means a parallel loop present either in all compute ops or in
@@ -3377,7 +3385,7 @@ void MultiLoweringConfigGenerator::splitCommonInnerVectorTiles() {
33773385
}
33783386

33793387
void MultiLoweringConfigGenerator::setNewTilingConfigs() {
3380-
SmallVector<TilingLevel> tilingLevels;
3388+
SmallVector<IREE::CPU::TilingLevel> tilingLevels;
33813389
tilingLevels.reserve(globalTileSizes.size());
33823390
for (const auto &entry : globalTileSizes) {
33833391
tilingLevels.push_back(entry.first);
@@ -3390,7 +3398,7 @@ void MultiLoweringConfigGenerator::setNewTilingConfigs() {
33903398
int numLoops = iterTypes.size();
33913399
SmallVector<IREE::CPU::LoweringConfigLevelInfo> newTilingInfo;
33923400
// Collect new tiling info.
3393-
for (TilingLevel level : tilingLevels) {
3401+
for (IREE::CPU::TilingLevel level : tilingLevels) {
33943402
SmallVector<int64_t> tileSizes(numLoops, 0);
33953403
SmallVector<bool> scalableFlags(numLoops, false);
33963404
for (auto [pos, iterType] : llvm::enumerate(iterTypes)) {
@@ -3402,7 +3410,7 @@ void MultiLoweringConfigGenerator::setNewTilingConfigs() {
34023410
// - If the loop dimension is not a reduction but the current tiling
34033411
// level is `VectorReductionTiles`, skip it.
34043412
if ((iterType == utils::IteratorType::reduction) ^
3405-
(level == TilingLevel::VectorReductionTiles)) {
3413+
(level == IREE::CPU::TilingLevel::VectorReductionTiles)) {
34063414
continue;
34073415
}
34083416
tileSizes[pos] = globalTileSizes[level][globalDimIdx];

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ static IREE::CPU::LoweringConfigAttr getLoweringConfigWithNewVectorSizes(
105105
using TilingLevel = IREE::CPU::TilingLevel;
106106
MLIRContext *ctx = loweringConfig.getContext();
107107
SmallVector<NamedAttribute> items;
108-
for (unsigned i = 0, e = TilingLevel::MaxNumTileLevels; i < e; ++i) {
109-
auto level = static_cast<TilingLevel>(i);
110-
if (!loweringConfig.hasTilingLevel(level)) {
108+
for (auto i : IREE::CPU::getTilingLevelsAsInts()) {
109+
if (!loweringConfig.hasTilingLevel(i)) {
111110
continue;
112111
}
112+
auto level = static_cast<TilingLevel>(i);
113113
switch (level) {
114114
case TilingLevel::DistributionTiles:
115115
case TilingLevel::CacheParallelTiles:

0 commit comments

Comments
 (0)