Skip to content

Commit 11d6cbc

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents b819852 + c2824c1 commit 11d6cbc

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,18 @@ struct SgPrefetchTileOpPattern
379379
XeGPUOneToNPatterRewriter &rewriter) const override {
380380
auto tileTy = op.getTile().getType();
381381
auto tiles = adaptor.getTile();
382-
if (tileTy.getRank() != 4)
382+
auto innerBlocks = tileTy.getInnerBlocks();
383+
384+
if (tileTy.getRank() != 2)
383385
return mlir::failure();
384-
auto shape = tileTy.getShape();
385386

386-
if (shape[0] * shape[1] != (int64_t)tiles.size()) {
387+
if (!innerBlocks || innerBlocks.size() != 2)
388+
return mlir::failure();
389+
390+
auto shape = tileTy.getShape();
391+
auto expectedNumTensorDescs =
392+
(shape[0] / innerBlocks[0]) * (shape[1] / innerBlocks[1]);
393+
if (expectedNumTensorDescs != (int64_t)tiles.size()) {
387394
op.emitOpError("Failed to lower LoadTileOp because shape[0] * shape[1] "
388395
"!= sources.size().");
389396
return mlir::failure();
@@ -396,12 +403,9 @@ struct SgPrefetchTileOpPattern
396403
auto L3 = xegpu::CacheReadHintAttr::get(op.getContext(),
397404
xegpu::CacheReadHint::CACHED);
398405

399-
for (int i = 0; i < shape[0]; i++) {
400-
for (int j = 0; j < shape[1]; j++) {
401-
auto tile = tiles[i * shape[1] + j];
402-
rewriter.create<xegpu::PrefetchNDOp>(op.getLoc(), tile, L1, L2, L3,
403-
imex::xegpu::Mode::VC);
404-
}
406+
for (auto tile : tiles) {
407+
rewriter.create<xegpu::PrefetchNDOp>(op.getLoc(), tile, L1, L2, L3,
408+
imex::xegpu::Mode::VC);
405409
}
406410

407411
rewriter.eraseOp(op);
@@ -630,7 +634,7 @@ struct SgUpdateTileOffsetOpPattern
630634
bool isLegalElementWiseOp(mlir::Operation *op) {
631635
auto res = op->getResult(0);
632636
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
633-
if (resType.getRank() != 2)
637+
if (resType && resType.getRank() != 2)
634638
return false;
635639
return true;
636640
}

lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
4141
addIllegalOp<imex::xetile::InitTileOp>();
4242

4343
addLegalOp<mlir::UnrealizedConversionCastOp>();
44-
4544
addLegalOp<mlir::vector::ExtractOp>();
4645
addLegalOp<mlir::vector::ExtractStridedSliceOp>();
4746
addLegalOp<mlir::vector::ShuffleOp>();

0 commit comments

Comments
 (0)