@@ -101,18 +101,6 @@ class TileUsageAnalysis {
101
101
}); // walk on LoadTileOp
102
102
};
103
103
104
- uint getUsage (imex::xetile::InitTileOp op) {
105
- if (Usage.count (op))
106
- return Usage[op];
107
- return UsageType::None;
108
- }
109
-
110
- uint getUsage (imex::xetile::LoadTileOp op) {
111
- if (Usage.count (op))
112
- return Usage[op];
113
- return UsageType::None;
114
- }
115
-
116
104
bool isForDPASA (imex::xetile::LoadTileOp op) {
117
105
if (Usage.count (op)) {
118
106
return Usage[op] & UsageType::DPAS_A;
@@ -200,14 +188,93 @@ class TileUsageAnalysis {
200
188
llvm::DenseMap<mlir::Operation *, uint> Usage;
201
189
};
202
190
191
+ // This analysis is used to propagate the inner block size of an operator
192
+ // to its uses or users. Current implementation is to propagate the MMA
193
+ // size used by an MMA operator to the definition (InitTileOp) for its operands.
194
+ // TODO: This analysis can be extended to propagate the block size for other ops
195
+ // such that it can be used as a general analysis for other block size
196
+ // optimizations.
197
+ class PropagateAnalysis {
198
+ private:
199
+ llvm::DenseMap<mlir::Operation *, mlir::DenseI64ArrayAttr> OpAttrMap;
200
+
201
+ public:
202
+ PropagateAnalysis (mlir::Operation *op) {
203
+ op->walk <mlir::WalkOrder::PostOrder>([&](xetile::TileMMAOp op) {
204
+ mlir::Operation *operation = op.getOperation ();
205
+ for (auto value : operation->getOperands ()) {
206
+ auto packOp = value.getDefiningOp <xetile::TilePackOp>();
207
+ if (packOp) {
208
+ auto blkSZ = packOp.getInnerBlocksAttr ();
209
+ propagate (value, blkSZ);
210
+ }
211
+ }
212
+ });
213
+ }
214
+
215
+ bool maybeUpdated (mlir::Operation *op) { return OpAttrMap.count (op); }
216
+
217
+ mlir::DenseI64ArrayAttr getValue (mlir::Operation *op) {
218
+ if (OpAttrMap.count (op))
219
+ return OpAttrMap[op];
220
+ return {};
221
+ }
222
+
223
+ private:
224
+ mlir::Operation *getDefineOrParentOp (mlir::Value value) {
225
+ if (llvm::isa<mlir::OpResult>(value))
226
+ return value.getDefiningOp ();
227
+ if (auto arg = llvm::dyn_cast_or_null<mlir::BlockArgument>(value))
228
+ return arg.getOwner ()->getParentOp ();
229
+ return nullptr ;
230
+ };
231
+
232
+ mlir::Value getOperandForArg (mlir::scf::ForOp &forOp, mlir::Value &value) {
233
+ auto arg = llvm::dyn_cast<mlir::BlockArgument>(value);
234
+ if (arg && arg.getArgNumber () >= forOp.getNumInductionVars ()) {
235
+ auto &iterOperand = *forOp.getTiedLoopInit (arg);
236
+ auto numCtrlOperands = forOp.getNumControlOperands ();
237
+ auto operandIdx = iterOperand.getOperandNumber ();
238
+ return forOp.getInitArgs ()[operandIdx - numCtrlOperands];
239
+ }
240
+ return mlir::Value ();
241
+ };
242
+
243
+ void propagate (mlir::Value start, mlir::DenseI64ArrayAttr attr) {
244
+ llvm::SmallVector<mlir::Value> queue;
245
+ if (bool (start))
246
+ queue.push_back (start);
247
+
248
+ while (queue.size ()) {
249
+ auto value = queue.pop_back_val ();
250
+ if (!bool (value))
251
+ continue ;
252
+
253
+ auto *op = getDefineOrParentOp (value);
254
+
255
+ // stop when meet a function.
256
+ if (!op || llvm::isa<mlir::FunctionOpInterface>(op))
257
+ return ;
258
+
259
+ OpAttrMap[op] = attr;
260
+
261
+ if (auto forOp = llvm::dyn_cast<mlir::scf::ForOp>(op)) {
262
+ auto opr = getOperandForArg (forOp, value);
263
+ if (bool (opr))
264
+ queue.push_back (opr);
265
+ } else if (op->getNumOperands () == 1 ) {
266
+ queue.push_back (op->getOperand (0 ));
267
+ }
268
+ }
269
+ }
270
+ };
271
+
203
272
class XeTypeConverter : public mlir ::OneToNTypeConverter {
204
273
public:
205
- friend class XeConversionPattern ;
274
+ // friend class XeConversionPattern;
206
275
using mlir::OneToNTypeConverter::convertType;
207
276
208
- XeTypeConverter (mlir::MLIRContext &context,
209
- TileUsageAnalysis *analysis = nullptr )
210
- : context(context), usageAnalysis(analysis) {
277
+ XeTypeConverter (mlir::MLIRContext &context) : context(context) {
211
278
addConversion ([&](xetile::TileType tileTy,
212
279
llvm::SmallVectorImpl<mlir::Type> &resultTypes)
213
280
-> std::optional<mlir::LogicalResult> {
@@ -235,72 +302,95 @@ class XeTypeConverter : public mlir::OneToNTypeConverter {
235
302
236
303
private:
237
304
mlir::MLIRContext &context;
238
-
239
- protected:
240
- TileUsageAnalysis *usageAnalysis;
241
305
};
242
306
243
307
// A simple mlir::RewritePattern wrapper with methods for accessing UsageType
308
+ template <typename AnalysisT>
244
309
class XeConversionPattern : public mlir ::RewritePattern {
245
310
public:
246
311
using mlir::RewritePattern::RewritePattern;
247
312
248
313
template <typename ... Args>
249
- XeConversionPattern (imex::XeTypeConverter &typeConverter, Args &&...args)
314
+ XeConversionPattern (imex::XeTypeConverter &typeConverter, AnalysisT &analysis,
315
+ Args &&...args)
250
316
: mlir::RewritePattern(std::forward<Args>(args)...),
251
- typeConverter (typeConverter) {}
317
+ typeConverter (typeConverter), analysis(analysis) {}
252
318
253
319
virtual mlir::LogicalResult
254
320
matchAndRewrite (mlir::Operation *op,
255
321
mlir::PatternRewriter &rewriter) const override {
256
322
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
257
323
};
258
324
325
+ imex::XeTypeConverter &getTypeConverter () const { return typeConverter; }
326
+
327
+ template <typename ConverterTy>
328
+ std::enable_if_t <std::is_base_of<mlir::TypeConverter, ConverterTy>::value,
329
+ ConverterTy &>
330
+ getTypeConverter () const {
331
+ return static_cast <ConverterTy &>(typeConverter);
332
+ }
333
+
334
+ protected:
335
+ imex::XeTypeConverter &typeConverter;
336
+ AnalysisT &analysis;
337
+
338
+ template <typename = typename std::enable_if<
339
+ std::is_same_v<AnalysisT, PropagateAnalysis>>>
340
+ mlir::DenseI64ArrayAttr getValue (mlir::Operation *op) const {
341
+ if (op)
342
+ return llvm::cast<PropagateAnalysis>(analysis).getValue (op);
343
+ return {};
344
+ }
345
+
346
+ template <typename = typename std::enable_if<
347
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
259
348
bool isForDPASA (imex::xetile::LoadTileOp op) const {
260
- return typeConverter. usageAnalysis -> isForDPASA (op);
349
+ return llvm::cast<TileUsageAnalysis>(analysis). isForDPASA (op);
261
350
}
262
351
352
+ template <typename = typename std::enable_if<
353
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
263
354
bool isForDPASB (imex::xetile::LoadTileOp op) const {
264
- return typeConverter. usageAnalysis -> isForDPASB (op);
355
+ return llvm::cast<TileUsageAnalysis>(analysis). isForDPASB (op);
265
356
}
266
357
358
+ template <typename = typename std::enable_if<
359
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
267
360
bool isForDPASC (imex::xetile::LoadTileOp op) const {
268
- return typeConverter. usageAnalysis -> isForDPASC (op);
361
+ return llvm::cast<TileUsageAnalysis>(analysis). isForDPASC (op);
269
362
}
270
363
364
+ template <typename = typename std::enable_if<
365
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
271
366
bool isForLoad (imex::xetile::InitTileOp op) const {
272
- return typeConverter. usageAnalysis -> isForLoad (op);
367
+ return llvm::cast<TileUsageAnalysis>(analysis). isForLoad (op);
273
368
}
274
369
370
+ template <typename = typename std::enable_if<
371
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
275
372
bool isForStore (imex::xetile::InitTileOp op) const {
276
- return typeConverter. usageAnalysis -> isForStore (op);
373
+ return llvm::cast<TileUsageAnalysis>(analysis). isForStore (op);
277
374
}
278
375
376
+ template <typename = typename std::enable_if<
377
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
279
378
bool isForPrefetch (imex::xetile::InitTileOp op) const {
280
- return typeConverter. usageAnalysis -> isForPrefetch (op);
379
+ return llvm::cast<TileUsageAnalysis>(analysis). isForPrefetch (op);
281
380
}
282
381
382
+ template <typename = typename std::enable_if<
383
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
283
384
bool isForLoadAndPrefetch (imex::xetile::InitTileOp op) const {
284
- return typeConverter. usageAnalysis -> isForLoadAndPrefetch (op);
385
+ return llvm::cast<TileUsageAnalysis>(analysis). isForLoadAndPrefetch (op);
285
386
}
286
387
388
+ template <typename = typename std::enable_if<
389
+ std::is_same_v<AnalysisT, TileUsageAnalysis>>>
287
390
bool isForLoadAndStore (imex::xetile::InitTileOp op) const {
288
- return typeConverter.usageAnalysis ->isForLoadAndStore (op);
289
- }
290
-
291
- imex::XeTypeConverter &getTypeConverter () const { return typeConverter; }
292
-
293
- template <typename ConverterTy>
294
- std::enable_if_t <std::is_base_of<mlir::TypeConverter, ConverterTy>::value,
295
- ConverterTy &>
296
- getTypeConverter () const {
297
- return static_cast <ConverterTy &>(typeConverter);
391
+ return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndStore (op);
298
392
}
299
-
300
- protected:
301
- imex::XeTypeConverter &typeConverter;
302
393
};
303
-
304
394
} // namespace imex
305
395
306
396
#endif
0 commit comments