 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -300,6 +299,9 @@ class LayoutInfoPropagation
   void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
 
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
@@ -352,6 +354,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
         visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
       })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
       /// No need to propagate the layout to operands in CreateNdDescOp because
       /// they are scalars (offsets, sizes, etc.).
       .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -400,6 +405,18 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
 }
 
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  /// Here we assign the default layout to the tensor descriptor operand of
+  /// prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  /// Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 /// Propagate the layout of the result tensor to the source tensor descriptor in
 /// UpdateNdOffsetOp.
 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
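Reviewer note (not part of the patch): visitor hooks like visitPrefetchNdOp are invoked by MLIR's sparse dataflow framework once the analysis is registered on a DataFlowSolver. The sketch below only illustrates that convention; the helper name runLayoutPropagation and the LayoutInfoPropagation constructor arguments (solver plus symbol table, the usual SparseBackwardDataFlowAnalysis signature) are assumptions and may differ from how this pass actually wires things up.

// Illustrative sketch only, assuming the standard dataflow-solver setup.
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
using namespace mlir;

static LogicalResult runLayoutPropagation(Operation *root) {
  SymbolTableCollection symbolTable;
  DataFlowSolver solver;
  // Commonly loaded alongside sparse analyses so the solver has liveness and
  // constant information available.
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::SparseConstantPropagation>();
  // Hypothetical registration of the analysis defined in this file.
  solver.load<LayoutInfoPropagation>(symbolTable);
  return solver.initializeAndRun(root);
}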
@@ -1488,6 +1505,39 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern
   }
 };
 
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+
+    auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
+
 /// Generic pattern for sinking a GPU index operation feeding into the yield op
 /// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
 /// becomes dead and an equivalent copy of the index op is created outside the
@@ -1562,9 +1612,9 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, UpdateNdOffsetDistribution>(patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, UpdateNdOffsetDistribution,
+               PrefetchNdDistribution>(patterns.getContext());
   /// TODO: Is this the right place to add these patterns?
   patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
               GpuIndexOpDistribution<gpu::BlockDimOp>,
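Usage note (not part of the diff): a pass consuming populateXeGPUSubgroupDistributePatterns would typically collect the patterns into a RewritePatternSet and hand them to a pattern driver. The greedy driver below is an assumption based on the GreedyPatternRewriteDriver.h include earlier in the file, and runSubgroupDistribution is a hypothetical helper, not the pass's actual entry point.

// Illustrative sketch only; the enclosing pass may drive the patterns
// differently or interleave them with other rewrites.
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

static void runSubgroupDistribution(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Apply until fixpoint; report non-convergence instead of asserting.
  if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
    root->emitError("subgroup distribution patterns did not converge");
}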