Commit 519d02a

add cse for cleaning up
1 parent 08ade3f commit 519d02a

File tree: 2 files changed, +8 −75 lines

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 5 additions & 71 deletions
@@ -1008,6 +1008,11 @@ struct MoveFuncBodyToWarpExecuteOnLane0
     rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
     rewriter.replaceOp(gpuFuncOp, newGpuFunc);
+    // At this point, we have moved the entire function body inside the warpOp.
+    // Now move any scalar uniform code outside of the warpOp (like GPU index
+    // ops, scalar constants, etc.). This will simplify the later lowering and
+    // avoid custom patterns for these ops.
+    vector::moveScalarUniformCode(warpOp);
     return success();
   }
 };
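The intended effect of the new vector::moveScalarUniformCode call is sketched below; the surrounding IR shapes and names are illustrative assumptions, not output copied from the pass:

```
// Before hoisting: the whole original function body, including uniform
// scalars such as constants and GPU index ops, lives inside the warp op.
gpu.warp_execute_on_lane_0(%laneid)[16] {
  %c8 = arith.constant 8 : index
  %bid = gpu.block_id x
  %off = arith.muli %bid, %c8 : index
  ...
  gpu.yield
}

// After hoisting: the uniform scalar code sits outside the warp op, so the
// later distribution patterns never need to handle these ops specially.
%c8 = arith.constant 8 : index
%bid = gpu.block_id x
%off = arith.muli %bid, %c8 : index
gpu.warp_execute_on_lane_0(%laneid)[16] {
  ...
  gpu.yield
}
```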
@@ -1412,63 +1417,6 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Generic pattern for sinking a GPU index operations feeding into yield op
-/// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
-/// becomes dead and an equivalent copy of the index op is created outside the
-/// warp op.
-/// Example:
-/// ```
-///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
-///     ...
-///     %index = gpu.block_id x : index
-///     gpu.yield %index
-///   }
-///   ...
-/// ```
-/// To
-/// ```
-///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
-///     ...
-///     %dead = gpu.block_id x : index
-///     gpu.yield %dead
-///   }
-///   %0 = gpu.block_id x : index
-///   ...
-/// ```
-template <typename IndexOp>
-struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
-    if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
-                                         "warp result is not a gpu index op");
-    Operation *indexOp = operand->get().getDefiningOp<IndexOp>();
-    unsigned operandIdx = operand->getOperandNumber();
-    SmallVector<Value, 3> newYieldValues;
-    SmallVector<Type, 3> newYieldTypes;
-    for (Value operand : indexOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      newYieldTypes.push_back(operand.getType());
-    }
-    SmallVector<size_t> newRetIndices;
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
-    rewriter.setInsertionPointAfter(newWarpOp);
-    SmallVector<Value> newIndexOperands;
-    for (size_t i : newRetIndices) {
-      newIndexOperands.push_back(newWarpOp.getResult(i));
-    }
-    auto newIndexOp = rewriter.create<IndexOp>(
-        newWarpOp.getLoc(), newIndexOperands,
-        removeTemporaryLayoutAttributes(indexOp->getAttrs()));
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
-    return success();
-  }
-};
-
 } // namespace
 
 namespace {
@@ -1488,20 +1436,6 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution>(patterns.getContext());
-  // TODO: Is this the right place to add these patterns?
-  patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
-               GpuIndexOpDistribution<gpu::BlockDimOp>,
-               GpuIndexOpDistribution<gpu::SubgroupIdOp>,
-               GpuIndexOpDistribution<gpu::SubgroupSizeOp>,
-               GpuIndexOpDistribution<gpu::NumSubgroupsOp>,
-               GpuIndexOpDistribution<gpu::ClusterDimOp>,
-               GpuIndexOpDistribution<gpu::ClusterDimBlocksOp>,
-               GpuIndexOpDistribution<gpu::ClusterIdOp>,
-               GpuIndexOpDistribution<gpu::ClusterBlockIdOp>,
-               GpuIndexOpDistribution<gpu::GridDimOp>,
-               GpuIndexOpDistribution<gpu::ThreadIdOp>,
-               GpuIndexOpDistribution<gpu::LaneIdOp>,
-               GpuIndexOpDistribution<gpu::GlobalIdOp>>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 3 additions & 4 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
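The -cse run added above folds pure operations that are identical to an earlier one. A minimal, hypothetical illustration of the kind of duplicate this cleans up after hoisting (ops and names chosen for illustration only):

```
// Before CSE: two identical, side-effect-free index ops.
%c8 = arith.constant 8 : index
%b0 = gpu.block_id x
%x  = arith.muli %b0, %c8 : index
%b1 = gpu.block_id x
%y  = arith.addi %b1, %c8 : index

// After CSE: the duplicate gpu.block_id is removed and both users share %b0.
%c8 = arith.constant 8 : index
%b0 = gpu.block_id x
%x  = arith.muli %b0, %c8 : index
%y  = arith.addi %b0, %c8 : index
```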
@@ -164,9 +164,9 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 // -----
 // CHECK-LABEL: gpu.func @gemm_loop
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
 // CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
 // CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
-// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
 // CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
@@ -181,9 +181,8 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 // CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
 // CHECK: scf.yield %[[T16]] : vector<8x1xf32>
 // CHECK: }
-// CHECK: %[[T8:.*]] = xegpu.create_nd_tdesc %[[ARG2]]{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
 gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
 %c0 = arith.constant 0 : index
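Read together, the updated CHECK lines say that CSE also deduplicates the C-matrix descriptor in this test: the second xegpu.create_nd_tdesc (previously captured as T8) disappears and the store reuses T2, created before the loop. A sketch of that effect, with value names mirroring the FileCheck captures and the accumulation loop elided:

```
// Before -cse: the descriptor for %arg2 is materialized twice.
%t2 = xegpu.create_nd_tdesc %arg2[%x_coord, %y_coord] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
%t3 = xegpu.load_nd %t2 : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
...
%t8 = xegpu.create_nd_tdesc %arg2[%x_coord, %y_coord] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %t9, %t8 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>

// After -cse: the redundant descriptor is gone and the store uses %t2 directly.
xegpu.store_nd %t9, %t2 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
```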
