Skip to content

Commit ee555d4

Browse files
committed
add tests
1 parent a76de60 commit ee555d4

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
11921192
newStoreOperands.push_back(resolveDistributedTy(
11931193
newWarpOp.getResult(newRetIndices[0]),
11941194
storeNdDistributedValueTyOrFailure.value(), rewriter));
1195-
// For the tensor descriptor operand, the layout attibute is dropped after
1195+
// For the tensor descriptor operand, the layout attribute is dropped after
11961196
// distribution. Types need to be resolved in this case also.
11971197
xegpu::TensorDescType distributedTensorDescTy =
11981198
dropLayouts(storeOp.getTensorDescType());
@@ -1444,7 +1444,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
14441444
/// (!xegpu.tensor_desc<4x8xf32, #lo0>) {
14451445
/// ...
14461446
/// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1447-
/// !xegpu.tensor_desc<4x8xf32, #lo0>
1447+
/// !xegpu.tensor_desc<4x8xf32, #lo0>
14481448
/// gpu.yield %update
14491449
/// }
14501450
/// ...
@@ -1455,7 +1455,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
14551455
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
14561456
/// ...
14571457
/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1458-
/// !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
1458+
/// !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
14591459
/// gpu.yield %dead, %arg0, %c32, %c16
14601460
/// }
14611461
/// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
@@ -1475,6 +1475,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14751475
subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
14761476
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
14771477
unsigned operandIdx = operand->getOperandNumber();
1478+
// new update op does not have layout attribute.
14781479
xegpu::TensorDescType newTensorDescTy =
14791480
dropLayouts(updateOp.getTensorDescType());
14801481

@@ -1494,13 +1495,16 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14941495
rewriter.setInsertionPointAfter(newWarpOp);
14951496
SmallVector<Value> newUpdateOperands;
14961497
for (size_t i : newRetIndices) {
1498+
// For the tensor descriptor operand, the layout attribute is dropped
1499+
// after distribution. Types need to be resolved in this case.
14971500
if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
14981501
newUpdateOperands.push_back(resolveDistributedTy(
14991502
newWarpOp.getResult(i), newTensorDescTy, rewriter));
15001503
} else {
15011504
newUpdateOperands.push_back(newWarpOp.getResult(i));
15021505
}
15031506
}
1507+
// Create a new update op outside the warp op.
15041508
auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
15051509
newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
15061510
removeTemporaryLayoutAttributes(updateOp->getAttrs()));
@@ -1510,6 +1514,32 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
15101514
}
15111515
};
15121516

1517+
/// Distribute a prefetch_nd op at the end of enclosing
1518+
/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
1519+
/// through the warp op interface they would be propagated as returned values.
1520+
/// Appropriate cast ops are inserted if the distributed types do not match
1521+
/// expected xegpu SIMT types.
1522+
///
1523+
/// Example:
1524+
///
1525+
/// ```
1526+
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1527+
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
1528+
/// ...
1529+
/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
1530+
/// }
1531+
/// ```
1532+
/// To
1533+
/// ```
1534+
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
1535+
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
1536+
/// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
1537+
/// }
1538+
/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
1539+
/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
1540+
/// xegpu.prefetch_nd %0 : !xegpu.tensor_desc<4x8xf32>
1541+
///
1542+
/// ```
15131543
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
15141544
using gpu::WarpDistributionPattern::WarpDistributionPattern;
15151545
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
@@ -1530,7 +1560,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
15301560
SmallVector<size_t> newRetIndices;
15311561
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
15321562
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1533-
1563+
// Create a new prefetch op outside the warp op with updated tensor
1564+
// descriptor type. Source tensor descriptor requires type resolution.
15341565
xegpu::TensorDescType newTensorDescTy =
15351566
dropLayouts(prefetchOp.getTensorDescType());
15361567
rewriter.setInsertionPointAfter(newWarpOp);

0 commit comments

Comments
 (0)