@@ -379,11 +379,18 @@ struct SgPrefetchTileOpPattern
379
379
XeGPUOneToNPatterRewriter &rewriter) const override {
380
380
auto tileTy = op.getTile ().getType ();
381
381
auto tiles = adaptor.getTile ();
382
- if (tileTy.getRank () != 4 )
382
+ auto innerBlocks = tileTy.getInnerBlocks ();
383
+
384
+ if (tileTy.getRank () != 2 )
383
385
return mlir::failure ();
384
- auto shape = tileTy.getShape ();
385
386
386
- if (shape[0 ] * shape[1 ] != (int64_t )tiles.size ()) {
387
+ if (!innerBlocks || innerBlocks.size () != 2 )
388
+ return mlir::failure ();
389
+
390
+ auto shape = tileTy.getShape ();
391
+ auto expectedNumTensorDescs =
392
+ (shape[0 ] / innerBlocks[0 ]) * (shape[1 ] / innerBlocks[1 ]);
393
+ if (expectedNumTensorDescs != (int64_t )tiles.size ()) {
387
394
op.emitOpError (" Failed to lower LoadTileOp because shape[0] * shape[1] "
388
395
" != sources.size()." );
389
396
return mlir::failure ();
@@ -396,12 +403,9 @@ struct SgPrefetchTileOpPattern
396
403
auto L3 = xegpu::CacheReadHintAttr::get (op.getContext (),
397
404
xegpu::CacheReadHint::CACHED);
398
405
399
- for (int i = 0 ; i < shape[0 ]; i++) {
400
- for (int j = 0 ; j < shape[1 ]; j++) {
401
- auto tile = tiles[i * shape[1 ] + j];
402
- rewriter.create <xegpu::PrefetchNDOp>(op.getLoc (), tile, L1, L2, L3,
403
- imex::xegpu::Mode::VC);
404
- }
406
+ for (auto tile : tiles) {
407
+ rewriter.create <xegpu::PrefetchNDOp>(op.getLoc (), tile, L1, L2, L3,
408
+ imex::xegpu::Mode::VC);
405
409
}
406
410
407
411
rewriter.eraseOp (op);
@@ -630,7 +634,7 @@ struct SgUpdateTileOffsetOpPattern
630
634
bool isLegalElementWiseOp (mlir::Operation *op) {
631
635
auto res = op->getResult (0 );
632
636
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType ());
633
- if (resType.getRank () != 2 )
637
+ if (resType && resType .getRank () != 2 )
634
638
return false ;
635
639
return true ;
636
640
}
0 commit comments