@@ -301,6 +301,10 @@ class LayoutInfoPropagation
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);

+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
         visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
       })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
       // No need to propagate the layout to operands in CreateNdDescOp because
       // they are scalars (offsets, sizes, etc.).
       .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }

+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Here we assign the default layout to the tensor descriptor operand of
+  // prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  // Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 void LayoutInfoPropagation::visitVectorMultiReductionOp(
     vector::MultiDimReductionOp reduction,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -1412,6 +1431,174 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };

+/// Sink an update_nd_offset op feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op that will not be used by the yield op (and should be cleaned
+/// up later). The yield op will bypass the updateOp's arguments. The tensor
+/// descriptor type is not distributed. Appropriate cast ops are inserted if
+/// the distributed types do not match the expected xegpu SIMT types.
+/// Example:
+/// ```
+///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %update
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:4 = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #lo0>,
+///                    !xegpu.tensor_desc<4x8xf32, #lo0>, index, index) {
+///     ...
+///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %dead, %arg0, %c32, %c16
+///   }
+///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+///        #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///   %1 = xegpu.update_nd_offset %0, [%c32, %c16]:
+///     !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    auto newTensorDescTy = dropLayouts(updateOp.getTensorDescType());
+
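+    // Yield all operands of the update op from the warp op. For the tensor
+    // descriptor the layout-free type is requested; the scalar offset
+    // operands keep their original types.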
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : updateOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      if (isa<xegpu::TensorDescType>(operand.getType())) {
+        newYieldTypes.push_back(newTensorDescTy);
+      } else {
+        newYieldTypes.push_back(operand.getType());
+      }
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
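+    // Rebuild the operand list for the new update op from the warp op
+    // results; resolveDistributedTy inserts a cast when a returned tensor
+    // descriptor type does not match the expected layout-free type.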
+    SmallVector<Value> newUpdateOperands;
+    for (auto i : newRetIndices) {
+      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+        newUpdateOperands.push_back(resolveDistributedTy(
+            newWarpOp.getResult(i), newTensorDescTy, rewriter));
+      } else {
+        newUpdateOperands.push_back(newWarpOp.getResult(i));
+      }
+    }
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    return success();
+  }
+};
+
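+/// Sink a prefetch_nd op at the end of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The source tensor descriptor is
+/// required to have a layout attribute. It is yielded from the warp op, and
+/// the prefetch is re-created after the warp op on a tensor descriptor with
+/// the layout dropped (the tensor descriptor type is not distributed); a cast
+/// is inserted if the types differ. The layout and shapes below are
+/// illustrative:
+/// ```
+///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
+///   }
+/// ```
+/// To
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
+///   }
+///   %0 = unrealized_conversion_cast %r: !xegpu.tensor_desc<4x8xf32, #lo0>
+///        -> !xegpu.tensor_desc<4x8xf32>
+///   xegpu.prefetch_nd %0 : !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```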
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
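+    // Yield the source tensor descriptor from the warp op so that the
+    // prefetch can be re-created after it.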
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+
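+    // Re-create the prefetch after the warp op on the layout-free tensor
+    // descriptor. resolveDistributedTy inserts a cast if the warp op result
+    // type does not match.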
+    auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
+
+/// Generic pattern for sinking a GPU index operation feeding into the yield
+/// op of an enclosing `gpu.warp_execute_on_lane_0` region. The original index
+/// op becomes dead and an equivalent copy of the index op is created outside
+/// the warp op.
+/// Example:
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %index = gpu.block_id x : index
+///     gpu.yield %index
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %dead = gpu.block_id x : index
+///     gpu.yield %dead
+///   }
+///   %0 = gpu.block_id x : index
+///   ...
+/// ```
+template <typename IndexOp>
+struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(subgroupOp,
+                                         "warp result is not a gpu index op");
+    auto indexOp = operand->get().getDefiningOp<IndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : indexOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      newYieldTypes.push_back(operand.getType());
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
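+    // Materialize an equivalent index op after the warp op, forwarding any
+    // operands yielded for it (GPU index ops generally have none).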
+    SmallVector<Value> newIndexOperands;
+    for (auto i : newRetIndices) {
+      newIndexOperands.push_back(newWarpOp.getResult(i));
+    }
+    auto newIndexOp = rewriter.create<IndexOp>(
+        newWarpOp.getLoc(), newIndexOperands,
+        removeTemporaryLayoutAttributes(indexOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
+    return success();
+  }
+};
+
 } // namespace

 namespace {
@@ -1430,7 +1617,22 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution>(patterns.getContext());
+  // TODO: Is this the right place to add these patterns?
+  patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
+               GpuIndexOpDistribution<gpu::BlockDimOp>,
+               GpuIndexOpDistribution<gpu::SubgroupIdOp>,
+               GpuIndexOpDistribution<gpu::SubgroupSizeOp>,
+               GpuIndexOpDistribution<gpu::NumSubgroupsOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimBlocksOp>,
+               GpuIndexOpDistribution<gpu::ClusterIdOp>,
+               GpuIndexOpDistribution<gpu::ClusterBlockIdOp>,
+               GpuIndexOpDistribution<gpu::GridDimOp>,
+               GpuIndexOpDistribution<gpu::ThreadIdOp>,
+               GpuIndexOpDistribution<gpu::LaneIdOp>,
+               GpuIndexOpDistribution<gpu::GlobalIdOp>>(patterns.getContext());
 }

 void XeGPUSubgroupDistributePass::runOnOperation() {