@@ -1192,7 +1192,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
11921192 newStoreOperands.push_back (resolveDistributedTy (
11931193 newWarpOp.getResult (newRetIndices[0 ]),
11941194 storeNdDistributedValueTyOrFailure.value (), rewriter));
1195- // For the tensor descriptor operand, the layout attibute is dropped after
1195+ // For the tensor descriptor operand, the layout attribute is dropped after
11961196 // distribution. Types needs to be resolved in this case also.
11971197 xegpu::TensorDescType distributedTensorDescTy =
11981198 dropLayouts (storeOp.getTensorDescType ());
@@ -1444,7 +1444,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
14441444// / (!xegpu.tensor_desc<4x8xf32, #lo0>) {
14451445// / ...
14461446// / %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1447- // / !xegpu.tensor_desc<4x8xf32, #lo0>
1447+ // / !xegpu.tensor_desc<4x8xf32, #lo0>
14481448// / gpu.yield %update
14491449// / }
14501450// / ...
@@ -1455,7 +1455,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
14551455// / !xegpu.tensor_desc<4x8xf32, #lo0>) {
14561456// / ...
14571457// / %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1458- // / !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
1458+ // / !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
14591459// / gpu.yield %dead, %arg0, %c32, %c16
14601460// / }
14611461// / %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
@@ -1475,6 +1475,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14751475 subgroupOp, " warp result is not a xegpu::UpdateNdOffset op" );
14761476 auto updateOp = operand->get ().getDefiningOp <xegpu::UpdateNdOffsetOp>();
14771477 unsigned operandIdx = operand->getOperandNumber ();
1478+ // new update op does not have layout attribute.
14781479 xegpu::TensorDescType newTensorDescTy =
14791480 dropLayouts (updateOp.getTensorDescType ());
14801481
@@ -1494,13 +1495,16 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14941495 rewriter.setInsertionPointAfter (newWarpOp);
14951496 SmallVector<Value> newUpdateOperands;
14961497 for (size_t i : newRetIndices) {
1498+ // For the tensor descriptor operand, the layout attribute is dropped
1499+ // / after distribution. Types need to be resolved in this case.
14971500 if (isa<xegpu::TensorDescType>(newWarpOp.getResult (i).getType ())) {
14981501 newUpdateOperands.push_back (resolveDistributedTy (
14991502 newWarpOp.getResult (i), newTensorDescTy, rewriter));
15001503 } else {
15011504 newUpdateOperands.push_back (newWarpOp.getResult (i));
15021505 }
15031506 }
1507+ // Create a new update op outside the warp op.
15041508 auto newUpdateOp = rewriter.create <xegpu::UpdateNdOffsetOp>(
15051509 newWarpOp.getLoc (), newTensorDescTy, newUpdateOperands,
15061510 removeTemporaryLayoutAttributes (updateOp->getAttrs ()));
@@ -1510,6 +1514,32 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
15101514 }
15111515};
15121516
1517+ // / Distribute a prefetch_nd op at the end of enclosing
1518+ // / `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
1519+ // / through the warp op interface they would be propagated as returned values.
1520+ // / Appropriate cast ops are inserted if the distributed types do not match
1521+ // / expected xegpu SIMT types.
1522+ // /
1523+ // / Example:
1524+ // /
1525+ // / ```
1526+ // / #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1527+ // / gpu.warp_execute_on_lane_0(%laneid) -> () {
1528+ // / ...
1529+ // / xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
1530+ // / }
1531+ // / ```
1532+ // / To
1533+ // / ```
1534+ // / %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
1535+ // / !xegpu.tensor_desc<4x8xf32, #lo0>) {
1536+ // / gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
1537+ // / }
1538+ // / %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
1539+ // / #lo0> -> !xegpu.tensor_desc<4x8xf32>
1540+ // / xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
1541+ // /
1542+ // / ```
15131543struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
15141544 using gpu::WarpDistributionPattern::WarpDistributionPattern;
15151545 LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op subgroupOp,
@@ -1530,7 +1560,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
15301560 SmallVector<size_t > newRetIndices;
15311561 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
15321562 rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1533-
1563+ // Create a new prefetch op outside the warp op with updated tensor
1564+ // descriptor type. Source tensor descriptor requires type resolution.
15341565 xegpu::TensorDescType newTensorDescTy =
15351566 dropLayouts (prefetchOp.getTensorDescType ());
15361567 rewriter.setInsertionPointAfter (newWarpOp);
0 commit comments