31
31
#include " mlir/Dialect/Func/IR/FuncOps.h"
32
32
#include " mlir/Dialect/Linalg/Passes.h"
33
33
#include " mlir/Dialect/MemRef/Transforms/Passes.h"
34
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
35
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
34
36
#include " mlir/Pass/PassManager.h"
35
37
#include " mlir/Transforms/Passes.h"
36
38
@@ -150,126 +152,129 @@ addTileAndDistributePasses(OpPassManager &funcPassManager,
150
152
// ===---------------------------------------------------------------------===//
151
153
152
154
/// Returns true when `interchange` is acceptable for a tiling level that has
/// `numLoops` tile sizes.
///
/// An empty `interchange` is always accepted. A non-empty one must contain
/// exactly `numLoops` entries and must satisfy `isPermutationVector`.
static bool isValidInterchange(ArrayRef<int64_t> interchange, int numLoops) {
  if (interchange.empty()) {
    return true;
  }
  // Compare lengths first: it is the cheap check, and the explicit guard plus
  // cast avoids an implicit signed/unsigned conversion of `numLoops`.
  if (numLoops < 0 || interchange.size() != static_cast<size_t>(numLoops)) {
    return false;
  }
  return isPermutationVector(interchange);
}
163
160
164
- // TODO(hanchung): Refresh the verifier after all the pipelines use
165
- // IREE::CPU::LoweringConfigAttr.
166
- LogicalResult verifyDoubleTilingExpertPassPipelineConfig (
167
- Operation *op, TilingConfig &tilingConfig,
168
- IREE::Codegen::TranslationInfoAttr translationInfo,
169
- ArrayRef<int64_t > workgroupSize) {
170
- if (!workgroupSize.empty ()) {
171
- return op->emitOpError (
172
- " expected workgroup size to be empty for CPU pipelines" );
173
- }
174
-
175
- // Verify that the translation info is using the right pipeline.
176
- if (translationInfo.getDispatchLoweringPassPipeline () !=
177
- IREE::Codegen::DispatchLoweringPassPipeline::CPUDoubleTilingExpert) {
178
- return op->emitOpError (" expected pipeline in translation_info to be " )
179
- << stringifyEnum (IREE::Codegen::DispatchLoweringPassPipeline::
180
- CPUDoubleTilingExpert);
181
- }
161
+ LogicalResult verifyMultiTilingExpertPassPipelineConfig (
162
+ Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
182
163
183
- if (tilingConfig. getNumTilingLevels () == 6 ) {
184
- // TODO: update verification.
164
+ auto interfaceOp = dyn_cast_or_null<TilingInterface>(op);
165
+ if (!interfaceOp) {
185
166
return success ();
186
167
}
187
168
188
- if (tilingConfig.getNumTilingLevels () != 4 ) {
189
- return op->emitOpError (" expected four tiling levels, got " )
190
- << tilingConfig.getNumTilingLevels ();
169
+ // Collects parallel loops.
170
+ llvm::SmallDenseSet<unsigned > pLoopsSet;
171
+ for (auto [index, iteratorType] :
172
+ llvm::enumerate (interfaceOp.getLoopIteratorTypes ())) {
173
+ if (iteratorType == utils::IteratorType::parallel) {
174
+ pLoopsSet.insert (index);
175
+ }
191
176
}
192
177
193
- auto interfaceOp = dyn_cast_or_null<TilingInterface>(op);
194
- if (interfaceOp) {
195
- llvm::SmallDenseSet<unsigned > pLoopsSet;
196
- for (auto [index, iteratorType] :
197
- llvm::enumerate (interfaceOp.getLoopIteratorTypes ())) {
198
- if (iteratorType == utils::IteratorType::parallel) {
199
- pLoopsSet.insert (index);
200
- }
178
+ for (int i = 0 , e = IREE::CPU::TilingLevel::MaxNumTileLevels; i < e; ++i) {
179
+ if (!loweringConfig.hasTilingLevel (i)) {
180
+ continue ;
201
181
}
202
182
203
- SmallVector<int64_t > secondLevelTileSizes;
204
- std::tie (secondLevelTileSizes, std::ignore) =
205
- tilingConfig.getVectorCommonParallelSizes ();
206
- for (auto [index, tileSize] : llvm::enumerate (secondLevelTileSizes)) {
207
- if (tileSize != 0 && !pLoopsSet.contains (index)) {
208
- return op->emitOpError (
209
- " expected only parallel dims to be set in the second tiling "
210
- " level, got " )
211
- << index << " -th tile size set" ;
183
+ auto level = static_cast <IREE::CPU::TilingLevel>(i);
184
+ auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
185
+ loweringConfig.getTilingLevelAttr (level));
186
+ switch (level) {
187
+ case IREE::CPU::TilingLevel::DistributionTiles:
188
+ case IREE::CPU::TilingLevel::CacheParallelTiles:
189
+ case IREE::CPU::TilingLevel::VectorCommonParallelTiles:
190
+ case IREE::CPU::TilingLevel::VectorInnerParallelTiles: {
191
+ for (auto [index, tileSize] :
192
+ llvm::enumerate (tilingLevelAttr.getSizes ())) {
193
+ if (tileSize != 0 && !pLoopsSet.contains (index)) {
194
+ return op->emitOpError (
195
+ " expected only parallel dims to be set in the " )
196
+ << IREE::CPU::getTilingLevelName (level)
197
+ << " tiling level, but tile size at index (" << index
198
+ << " ) was also set" ;
199
+ }
212
200
}
201
+ break ;
213
202
}
214
-
215
- SmallVector<int64_t > thirdLevelTileSizes;
216
- std::tie (thirdLevelTileSizes, std::ignore) =
217
- tilingConfig.getVectorReductionSizes ();
218
- for (auto [index, tileSize] : llvm::enumerate (thirdLevelTileSizes)) {
219
- if (tileSize != 0 && pLoopsSet.contains (index)) {
220
- return op->emitOpError (
221
- " expected only reduction dims to be set in the third tiling "
222
- " level, got " )
223
- << index << " -th tile size set" ;
203
+ case IREE::CPU::TilingLevel::CacheReductionTiles:
204
+ case IREE::CPU::TilingLevel::VectorReductionTiles: {
205
+ for (auto [index, tileSize] :
206
+ llvm::enumerate (tilingLevelAttr.getSizes ())) {
207
+ if (tileSize != 0 && pLoopsSet.contains (index)) {
208
+ return op->emitOpError (
209
+ " expected only reduction dims to be set in the " )
210
+ << IREE::CPU::getTilingLevelName (level)
211
+ << " tiling level, but tile size at index (" << index
212
+ << " ) was also set" ;
213
+ }
224
214
}
215
+ break ;
225
216
}
226
- }
217
+ case IREE::CPU::TilingLevel::MaxNumTileLevels:
218
+ case IREE::CPU::TilingLevel::InvalidLevel:
219
+ break ;
220
+ };
227
221
228
- // Verify interchange.
229
- for (int level = 0 ; level < tilingConfig.getNumTilingLevels (); level++) {
230
- IREE::Codegen::LoweringConfigTilingLevelAttr attr =
231
- tilingConfig.getTilingLevelAttr (level);
232
- ArrayRef<int64_t > interchange = attr.getInterchange ();
233
- size_t expectedSize = attr.getSizes ().size ();
234
- if (!interchange.empty () &&
235
- !isValidInterchange (interchange, expectedSize)) {
222
+ ArrayRef<int64_t > interchange = tilingLevelAttr.getInterchange ();
223
+ size_t expectedSize = tilingLevelAttr.getSizes ().size ();
224
+ if (!isValidInterchange (interchange, expectedSize)) {
236
225
return op->emitOpError (" expected [0, " )
237
- << expectedSize << " ) to be set exactly once in interchange # "
238
- << level;
226
+ << expectedSize << " ) to be set exactly once in interchange for "
227
+ << IREE::CPU::getTilingLevelName ( level) << " tiling level " ;
239
228
}
240
229
}
230
+
241
231
return success ();
242
232
}
243
233
244
234
LogicalResult verifyConvTileAndDecomposeExpertConfig (
245
- Operation *op, TilingConfig &tilingConfig,
246
- IREE::Codegen::TranslationInfoAttr translationInfo,
247
- ArrayRef<int64_t > workgroupSize) {
248
- if (!isa<linalg::ConvolutionOpInterface>(op))
249
- return success ();
250
-
251
- if (tilingConfig.getNumTilingLevels () == 6 ) {
252
- // TODO: update verification.
235
+ Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
236
+ if (!isa<linalg::ConvolutionOpInterface>(op)) {
253
237
return success ();
254
238
}
255
239
256
- if (tilingConfig.getNumTilingLevels () != 3 ) {
257
- return op->emitOpError (" expected three tiling levels, got " )
258
- << tilingConfig.getNumTilingLevels ();
259
- }
240
+ auto getTileSizeAtIndex = [](ArrayRef<int64_t > sizes,
241
+ ArrayRef<bool > scalableFlags,
242
+ unsigned index) -> std::pair<int64_t , bool > {
243
+ return std::make_pair (sizes[index],
244
+ index < scalableFlags.size () && scalableFlags[index]);
245
+ };
260
246
247
+ SmallVector<IREE::CPU::TilingLevel> requiredLevels = {
248
+ IREE::CPU::DistributionTiles, IREE::CPU::VectorCommonParallelTiles,
249
+ IREE::CPU::VectorReductionTiles};
261
250
linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
262
- SmallVector<int64_t > shape = linalgOp.getStaticLoopRanges ();
263
- for (auto sizes : tilingConfig.getTileSizes ()) {
264
- for (auto [i, size] : llvm::enumerate (sizes)) {
265
- if (size == 1 )
266
- shape[i] = 1 ;
267
- if (shape[i] == -1 || size == 0 )
251
+ SmallVector<int64_t > shapeAfterTiling = linalgOp.getStaticLoopRanges ();
252
+ for (auto level : requiredLevels) {
253
+ if (!loweringConfig.hasTilingLevel (level)) {
254
+ return op->emitOpError (" expected " )
255
+ << IREE::CPU::getTilingLevelName (level) << " is set" ;
256
+ }
257
+ auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
258
+ loweringConfig.getTilingLevelAttr (level));
259
+ for (size_t i = 0 , e = tilingLevelAttr.getSizes ().size (); i < e; ++i) {
260
+ auto [size, scalableFlag] = getTileSizeAtIndex (
261
+ tilingLevelAttr.getSizes (), tilingLevelAttr.getScalableFlags (), i);
262
+ if (scalableFlag) {
263
+ shapeAfterTiling[i] = ShapedType::kDynamic ;
264
+ continue ;
265
+ }
266
+ if (size == 1 ) {
267
+ shapeAfterTiling[i] = 1 ;
268
+ continue ;
269
+ }
270
+ if (ShapedType::isDynamicShape (shapeAfterTiling[i]) ||
271
+ ShapedType::isDynamic (size) || size == 0 ) {
268
272
continue ;
269
- if (shape[i] % size != 0 ) {
270
- shape[i] = -1 ;
273
+ }
274
+ if (shapeAfterTiling[i] % size != 0 ) {
275
+ shapeAfterTiling[i] = ShapedType::kDynamic ;
271
276
} else {
272
- shape [i] = size;
277
+ shapeAfterTiling [i] = size;
273
278
}
274
279
}
275
280
}
@@ -281,27 +286,27 @@ LogicalResult verifyConvTileAndDecomposeExpertConfig(
281
286
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
282
287
linalg::PoolingNhwcMaxUnsignedOp, linalg::PoolingNhwcMinOp,
283
288
linalg::PoolingNhwcMinUnsignedOp>([&](auto ) {
284
- // Shape : N, OH, OW, OC, KH, KW, (IC)
285
- khSize = shape [4 ];
286
- kwSize = shape [5 ];
287
- ohSize = shape [1 ];
288
- owSize = shape [2 ];
289
+ // shape : N, OH, OW, OC, KH, KW, (IC)
290
+ khSize = shapeAfterTiling [4 ];
291
+ kwSize = shapeAfterTiling [5 ];
292
+ ohSize = shapeAfterTiling [1 ];
293
+ owSize = shapeAfterTiling [2 ];
289
294
return success ();
290
295
})
291
296
.Case <linalg::Conv2DNchwFchwOp>([&](auto ) {
292
- // Shape : N, OC, OH, OW, (IC), KH, KW
293
- khSize = shape [5 ];
294
- kwSize = shape [6 ];
295
- ohSize = shape [2 ];
296
- owSize = shape [3 ];
297
+ // shape : N, OC, OH, OW, (IC), KH, KW
298
+ khSize = shapeAfterTiling [5 ];
299
+ kwSize = shapeAfterTiling [6 ];
300
+ ohSize = shapeAfterTiling [2 ];
301
+ owSize = shapeAfterTiling [3 ];
297
302
return success ();
298
303
})
299
304
.Case <linalg::PoolingNchwSumOp, linalg::PoolingNchwMaxOp>([&](auto ) {
300
- // Shape : N, OC, OH, OW, KH, KW
301
- khSize = shape [4 ];
302
- kwSize = shape [5 ];
303
- ohSize = shape [2 ];
304
- owSize = shape [3 ];
305
+ // shape : N, OC, OH, OW, KH, KW
306
+ khSize = shapeAfterTiling [4 ];
307
+ kwSize = shapeAfterTiling [5 ];
308
+ ohSize = shapeAfterTiling [2 ];
309
+ owSize = shapeAfterTiling [3 ];
305
310
return success ();
306
311
})
307
312
.Default ([&](auto ) { return failure (); });
0 commit comments