@@ -301,6 +301,10 @@ class LayoutInfoPropagation
                          ArrayRef<LayoutInfoLattice *> operands,
                          ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
           .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
             visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
           })
+          .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+            visitPrefetchNdOp(prefetchNdOp, operands, results);
+          })
           // No need to propagate the layout to operands in CreateNdDescOp because
           // they are scalars (offsets, sizes, etc.).
           .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }
 
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Here we assign the default layout to the tensor descriptor operand of
+  // prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  // Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 void LayoutInfoPropagation::visitVectorMultiReductionOp(
     vector::MultiDimReductionOp reduction,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -865,18 +884,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   return VectorType::get(distributedShape, originalType.getElementType());
 }
 
-// Drop the layout attribute from the tensor descriptor type if layout is
-// present.
-static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
-  if (tensorDesc.getLayoutAttr() == xegpu::LayoutAttr())
-    return tensorDesc;
-
-  return xegpu::TensorDescType::get(
-      tensorDesc.getContext(), tensorDesc.getShape(),
-      tensorDesc.getElementType(), tensorDesc.getEncoding(),
-      xegpu::LayoutAttr());
-}
-
 /// Helper function to resolve types if the distributed type out of
 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
 /// Example 1:
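The static `dropLayouts` helper removed above is superseded by a member on `xegpu::TensorDescType`, invoked in the patterns below as `descOp.getType().dropLayouts()`. A minimal sketch of what that member presumably does, mirroring the deleted helper (the actual in-tree declaration and definition may differ):

```c++
// Sketch only, assuming the member mirrors the removed static helper:
// return the same tensor descriptor type with its layout attribute cleared.
xegpu::TensorDescType xegpu::TensorDescType::dropLayouts() {
  // Nothing to drop if no layout attribute is attached.
  if (getLayoutAttr() == xegpu::LayoutAttr())
    return *this;

  // Rebuild the type, keeping shape, element type, and encoding intact.
  return xegpu::TensorDescType::get(getContext(), getShape(), getElementType(),
                                    getEncoding(), xegpu::LayoutAttr());
}
```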
@@ -1023,12 +1030,12 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
-///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     ...
 ///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
-///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///     vector.yield %td
 ///   }
 /// ```
@@ -1037,7 +1044,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
 ///     ...
 ///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
-///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///     vector.yield %arg0, %dead
 ///   }
 ///   %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
@@ -1080,8 +1087,8 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     }
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(descOp.getType()); // Distributed tensor descriptor type
-                                       // does not contain layout info.
+        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
+                                        // does not contain layout info.
     auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
         newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
@@ -1101,23 +1108,23 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   gpu.warp_execute_on_lane_0(%laneid) -> () {
 ///     ...
 ///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
-///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///       !xegpu.tensor_desc<4x8xf32, #layout0>
 ///   }
 /// ```
 /// To
 /// ```
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
-///     !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
-///       #lo0>
+///       #layout0>
 ///   }
 ///   %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
 ///   %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
-///     #lo0>
+///     #layout0>
 ///     -> !xegpu.tensor_desc<4x8xf32>
 ///   xegpu.store_nd %0, %1: vector<4xf32>,
 ///     !xegpu.tensor_desc<4x8xf32>
@@ -1173,10 +1180,10 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     newStoreOperands.push_back(resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]),
         storeNdDistributedValueTyOrFailure.value(), rewriter));
-    // For the tensor descriptor operand, the layout attibute is dropped after
+    // For the tensor descriptor operand, the layout attribute is dropped after
     // distribution. Types needs to be resolved in this case also.
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(storeOp.getTensorDescType());
+        storeOp.getTensorDescType().dropLayouts();
     newStoreOperands.push_back(
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
@@ -1201,25 +1208,26 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
 ///                   (vector<4x1xf32>) {
 ///     ...
-///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #lo0> ->
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
+///       ->
 ///       vector<4x8xf32>
 ///     gpu.yield %ld
 ///   }
 /// ```
 /// To
 /// ```
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
-///     !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     ...
-///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #lo0> ->
+///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
 ///       vector<4x8xf32> gpu.yield %dead, %arg0
 ///   }
 ///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
-///     #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
 ///   %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
 ///   %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
 ///
@@ -1260,9 +1268,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           loadOp, "Failed to get distributed vector type for the load op");
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(loadOp.getTensorDescType()); // Distributed tensor
-                                                 // descriptor type does not
-                                                 // contain layout info.
+        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
+                                                  // descriptor type does not
+                                                  // contain layout info.
     auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
@@ -1412,6 +1420,152 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink an update_nd_offset op feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op, which will not be used by the yield op (and should be cleaned
+/// up later). The yield op will bypass the update op and instead yield its
+/// arguments. The tensor descriptor type is not distributed. Appropriate cast
+/// ops are inserted if the distributed types do not match xegpu SIMT types.
+/// Example:
+/// ```
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
+///     ...
+///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #layout0>
+///     gpu.yield %update
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
+///     !xegpu.tensor_desc<4x8xf32, #layout0>,
+///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
+///     ...
+///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #layout0>
+///     gpu.yield %dead, %arg0, %c32, %c16
+///   }
+///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+///        #layout0> -> !xegpu.tensor_desc<4x8xf32>
+///   %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
+///     !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    // The new update op does not have a layout attribute.
+    xegpu::TensorDescType newTensorDescTy =
+        updateOp.getTensorDescType().dropLayouts();
+
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (Value operand : updateOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      if (isa<xegpu::TensorDescType>(operand.getType())) {
+        newYieldTypes.push_back(newTensorDescTy);
+      } else {
+        newYieldTypes.push_back(operand.getType());
+      }
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newUpdateOperands;
+    for (size_t i : newRetIndices) {
+      // For the tensor descriptor operand, the layout attribute is dropped
+      // after distribution. Types need to be resolved in this case.
+      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+        newUpdateOperands.push_back(resolveDistributedTy(
+            newWarpOp.getResult(i), newTensorDescTy, rewriter));
+      } else {
+        newUpdateOperands.push_back(newWarpOp.getResult(i));
+      }
+    }
+    // Create a new update op outside the warp op.
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    return success();
+  }
+};
+
+/// Distribute a prefetch_nd op at the end of the enclosing
+/// `gpu.warp_execute_on_lane_0`. If the prefetch's arguments are passed
+/// through the warp op interface, they are propagated as returned values.
+/// The tensor descriptor shape is not distributed because it is a uniform
+/// value across all work items within the subgroup. Appropriate cast ops are
+/// inserted if the distributed types do not match expected xegpu SIMT types.
+///
+/// Example:
+///
+/// ```
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
+///   }
+/// ```
+/// To
+/// ```
+///   %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
+///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
+///   }
+///   %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
+///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
+///   xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    // Create a new prefetch op outside the warp op with the updated tensor
+    // descriptor type. The source tensor descriptor requires type resolution.
+    xegpu::TensorDescType newTensorDescTy =
+        prefetchOp.getTensorDescType().dropLayouts();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
+
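Both new patterns rebuild the op with `removeTemporaryLayoutAttributes(...)`, a helper that is not part of this diff. A rough sketch of the presumed idea, under the assumption that it filters out the temporary layout annotations attached during layout propagation (the attribute-name prefix below is illustrative, not the actual one):

```c++
// Sketch only: the real helper is defined elsewhere in this file. Assumption:
// it drops the per-operand/per-result layout attributes used as temporary
// bookkeeping by the propagation analysis, so they do not leak into the
// distributed SIMT ops.
static SmallVector<NamedAttribute>
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> newAttrs;
  for (NamedAttribute attr : attrs) {
    // Keep everything that does not look like a temporary layout annotation
    // (prefix is illustrative).
    if (!attr.getName().getValue().starts_with("layout"))
      newAttrs.push_back(attr);
  }
  return newAttrs;
}
```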
 } // namespace
 
 namespace {
@@ -1430,7 +1584,8 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
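For context, the populate function above is how the two new patterns reach a rewrite driver. A minimal usage sketch, assuming the standard greedy driver (`applyPatternsGreedily`, called `applyPatternsAndFoldGreedily` in older MLIR); the in-tree pass may wire this up differently, e.g. together with the upstream gpu/vector distribution patterns:

```c++
// Sketch only: collect the XeGPU subgroup-distribution patterns and run them
// greedily over an op. The function name and error handling are illustrative.
void runXeGPUSubgroupDistribution(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Applies CreateNdDesc/LoadNd/StoreNd/Dpas/PrefetchNd/UpdateNdOffset
  // distribution until a fixpoint is reached.
  if (failed(applyPatternsGreedily(root, std::move(patterns))))
    root->emitError("XeGPU subgroup distribution did not converge");
}
```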