@@ -1063,30 +1063,12 @@ class LoweringConfigGenerator {
1063
1063
};
1064
1064
1065
1065
// / Returns the same lowering_config attribute with the updated tile sizes and
1066
- // / scalable tile flags. The `setDistrubtionConfig` flag is only available when
1067
- // / `origLoweringConfig is a IREE::CPU::LoweringConfigAttr. The distribution
1068
- // / tiling sizes is not set if it is false.
1069
- // / See `Codegen/Common/TileSizeSelection.h` for the convention of mapping
1070
- // / between tiling levels.
1071
- static IREE::Codegen::LoweringConfigAttrInterface getNewLoweringConfig (
1072
- IREE::Codegen::LoweringConfigAttrInterface origLoweringConfig,
1073
- ArrayRef<IREE::CPU::LoweringConfigLevelInfo> tilingInfo,
1074
- bool setDistributionConfig) {
1075
- assert ((isa<IREE::Codegen::LoweringConfigAttr, IREE::CPU::LoweringConfigAttr>(
1076
- origLoweringConfig)));
1077
- MLIRContext *ctx = origLoweringConfig.getContext ();
1078
- if (isa<IREE::Codegen::LoweringConfigAttr>(origLoweringConfig)) {
1079
- TileSizesListType tileSizesList;
1080
- ScalableTileFlagsListType scalableTileFlagsList;
1081
- for (auto [level, tileSizes, scalableFlags] : tilingInfo) {
1082
- (void )level;
1083
- tileSizesList.push_back (tileSizes);
1084
- scalableTileFlagsList.push_back (scalableFlags);
1085
- }
1086
- return IREE::Codegen::LoweringConfigAttr::get (ctx, tileSizesList,
1087
- scalableTileFlagsList);
1088
- }
1089
-
1066
+ // / scalable tile flags. The distribution tiling sizes is not set if it is
1067
+ // / false.
1068
+ static IREE::Codegen::LoweringConfigAttrInterface
1069
+ getNewLoweringConfig (MLIRContext *ctx,
1070
+ ArrayRef<IREE::CPU::LoweringConfigLevelInfo> tilingInfo,
1071
+ bool setDistributionConfig) {
1090
1072
SmallVector<NamedAttribute> newItems;
1091
1073
for (auto [level, tileSizes, scalableFlags] : tilingInfo) {
1092
1074
if (!setDistributionConfig && level == TilingLevel::DistributionTiles) {
@@ -2751,13 +2733,16 @@ static LogicalResult
2751
2733
adjustTileSizesForUnPackOp (mlir::FunctionOpInterface entryPointFn,
2752
2734
Operation *rootOp) {
2753
2735
auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
2754
- if (!linalgOp)
2736
+ if (!linalgOp) {
2737
+ return success ();
2738
+ }
2739
+ auto loweringConfig =
2740
+ getLoweringConfig<IREE::CPU::LoweringConfigAttr>(linalgOp);
2741
+ if (!loweringConfig) {
2742
+ // Tile size adjustment is only available when the rootOp uses
2743
+ // IREE::CPU::LoweringConfigAttr.
2755
2744
return success ();
2756
- IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
2757
- getLoweringConfig (linalgOp);
2758
- std::unique_ptr<TilingConfig> tilingConfig =
2759
- TilingConfig::create (loweringConfig);
2760
- TileSizesListType tileSizesList = tilingConfig->getTileSizes ();
2745
+ }
2761
2746
2762
2747
bool foundUnPackOp = false ;
2763
2748
SmallVector<int64_t > alignedSizes (linalgOp.getNumLoops (), 1 );
@@ -2792,7 +2777,7 @@ adjustTileSizesForUnPackOp(mlir::FunctionOpInterface entryPointFn,
2792
2777
2793
2778
// Fixup for making tileSizes be multiple of inner_tile_sizes.
2794
2779
SmallVector<IREE::CPU::LoweringConfigLevelInfo> tilingInfo =
2795
- tilingConfig-> getTilingLevelInfo ();
2780
+ loweringConfig. getAvailableTilingInfo ();
2796
2781
for (IREE::CPU::LoweringConfigLevelInfo &info : tilingInfo) {
2797
2782
SmallVector<int64_t > &tileSizes = info.sizes ;
2798
2783
for (auto idx : llvm::seq<int64_t >(0 , tileSizes.size ())) {
@@ -2824,7 +2809,7 @@ adjustTileSizesForUnPackOp(mlir::FunctionOpInterface entryPointFn,
2824
2809
}
2825
2810
2826
2811
IREE::Codegen::LoweringConfigAttrInterface newLoweringConfig =
2827
- getNewLoweringConfig (loweringConfig , tilingInfo,
2812
+ getNewLoweringConfig (rootOp-> getContext () , tilingInfo,
2828
2813
/* setDistributionConfig=*/ true );
2829
2814
return setOpConfigAndEntryPointFnTranslation (
2830
2815
entryPointFn, rootOp, newLoweringConfig, pipeline, /* workgroupSize=*/ {},
@@ -2947,7 +2932,13 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
2947
2932
return success ();
2948
2933
}
2949
2934
2950
- auto rootLoweringConfig = getLoweringConfig (rootOperation);
2935
+ auto rootLoweringConfig =
2936
+ getLoweringConfig<IREE::CPU::LoweringConfigAttr>(rootOperation);
2937
+ if (!rootLoweringConfig) {
2938
+ // Propagation is only available for IREE::CPU::LoweringConfigAttr.
2939
+ return success ();
2940
+ }
2941
+
2951
2942
SmallVector<int64_t > distTileSizes, parallelVecTileSizes;
2952
2943
SmallVector<bool > distScalableTileSizes, parallelVecScalableTileSizes;
2953
2944
assert (rootLoweringConfig.hasWorkgroupTilingLevel ());
@@ -3084,19 +3075,17 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
3084
3075
}
3085
3076
3086
3077
// Set the lowering configs with new tile sizes.
3087
- // TODO(hanchung): Deprecate TilingConfig from the file.
3088
- std::unique_ptr<TilingConfig> tilingConfig =
3089
- TilingConfig::create (rootLoweringConfig);
3090
3078
for (auto op : computeOps) {
3091
3079
int numLoops = cast<TilingInterface>(op).getLoopIteratorTypes ().size ();
3092
3080
SmallVector<IREE::CPU::LoweringConfigLevelInfo> newTilingInfo;
3093
3081
// For root op, we patch the adjusted tile sizes on its original tiling
3094
3082
// config.
3095
3083
if (op == rootOperation) {
3096
- newTilingInfo = tilingConfig-> getTilingLevelInfo ();
3084
+ newTilingInfo = rootLoweringConfig. getAvailableTilingInfo ();
3097
3085
updateOrAddTilingLevelInfo (newTilingInfo, IREE::CPU::DistributionTiles,
3098
3086
distTileSizes, distScalableTileSizes);
3099
- if (tilingConfig->getNumTilingLevels () > 1 ) {
3087
+ if (rootLoweringConfig.hasTilingLevel (
3088
+ IREE::CPU::VectorCommonParallelTiles)) {
3100
3089
updateOrAddTilingLevelInfo (
3101
3090
newTilingInfo, IREE::CPU::VectorCommonParallelTiles,
3102
3091
commonVecTileSizes, commonVecScalableTileFlags);
@@ -3110,27 +3099,30 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
3110
3099
distTileSizes, falseVec);
3111
3100
// The cache level tiling sizes are not adjusted, so we use the
3112
3101
// config from the rootOp directly.
3113
- if (tilingConfig->isValidLevel (IREE::CPU::CacheParallelTiles)) {
3114
- updateOrAddTilingLevelInfo (newTilingInfo, IREE::CPU::CacheParallelTiles,
3115
- tilingConfig->getCacheParallelSizes (),
3116
- falseVec);
3102
+ if (rootLoweringConfig.hasTilingLevel (IREE::CPU::CacheParallelTiles)) {
3103
+ updateOrAddTilingLevelInfo (
3104
+ newTilingInfo, IREE::CPU::CacheParallelTiles,
3105
+ rootLoweringConfig.getStaticTilingLevelSizes (
3106
+ IREE::CPU::CacheParallelTiles, rootOperation),
3107
+ falseVec);
3117
3108
}
3118
- if (tilingConfig-> isValidLevel (IREE::CPU::CacheReductionTiles)) {
3109
+ if (rootLoweringConfig. hasTilingLevel (IREE::CPU::CacheReductionTiles)) {
3119
3110
updateOrAddTilingLevelInfo (
3120
3111
newTilingInfo, IREE::CPU::CacheReductionTiles,
3121
- tilingConfig->getCacheReductionSizes (), falseVec);
3112
+ rootLoweringConfig.getStaticTilingLevelSizes (
3113
+ IREE::CPU::CacheReductionTiles, rootOperation),
3114
+ falseVec);
3122
3115
}
3123
3116
updateOrAddTilingLevelInfo (
3124
3117
newTilingInfo, IREE::CPU::VectorCommonParallelTiles,
3125
3118
commonVecTileSizes, commonVecScalableTileFlags);
3126
3119
bool setUpOK =
3127
3120
TypeSwitch<Operation *, bool >(op)
3128
3121
.Case <linalg::PackOp>([&](auto packOp) {
3129
- for (ArrayRef<bool > flags :
3130
- tilingConfig->getScalableTileFlags ()) {
3131
- // TODO: Handle scalable flags
3132
- if (llvm::any_of (flags, [&](bool flag) { return flag; }))
3133
- return false ;
3122
+ // TODO: Handle scalable flags
3123
+ if (llvm::any_of (rootLoweringConfig.getVectorScalableFlags (),
3124
+ [&](bool flag) { return flag; })) {
3125
+ return false ;
3134
3126
}
3135
3127
updateOrAddTilingLevelInfo (newTilingInfo,
3136
3128
IREE::CPU::VectorReductionTiles,
@@ -3202,7 +3194,7 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
3202
3194
return lhs.level < rhs.level ;
3203
3195
});
3204
3196
IREE::Codegen::LoweringConfigAttrInterface config =
3205
- getNewLoweringConfig (rootLoweringConfig , newTilingInfo,
3197
+ getNewLoweringConfig (rootOperation-> getContext () , newTilingInfo,
3206
3198
/* setDistributionConfig=*/ op == rootOperation);
3207
3199
setLoweringConfig (op, config);
3208
3200
}
0 commit comments