Skip to content

Commit 533ddcd

Browse files
authored
[mlir][gpu] Warp execute terminator getter (#154729)
Adds a utility getter to `warp_execute_on_lane_0` which simplifies access to the op's terminator. Uses are refactored to utilize the new terminator getter.
1 parent 538e9e8 commit 533ddcd

File tree

5 files changed

+18
-22
lines changed

5 files changed

+18
-22
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3209,6 +3209,9 @@ def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
32093209
bool isDefinedOutsideOfRegion(Value value) {
32103210
return !getRegion().isAncestor(value.getParentRegion());
32113211
}
3212+
3213+
/// Get the terminator of the warp region.
3214+
gpu::YieldOp getTerminator();
32123215
}];
32133216
}
32143217

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
24862486
if (getArgs().size() != getWarpRegion().getNumArguments())
24872487
return emitOpError(
24882488
"expected same number op arguments and block arguments.");
2489-
auto yield =
2490-
cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2489+
gpu::YieldOp yield = getTerminator();
24912490
if (yield.getNumOperands() != getNumResults())
24922491
return emitOpError(
24932492
"expected same number of yield operands and return values.");
@@ -2511,6 +2510,10 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
25112510
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
25122511
}
25132512

2513+
gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2514+
return cast<gpu::YieldOp>(getBody()->getTerminator());
2515+
}
2516+
25142517
//===----------------------------------------------------------------------===//
25152518
// GPU KernelMetadataAttr
25162519
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
5656
SmallVector<size_t> &indices) const {
5757
SmallVector<Type> types(warpOp.getResultTypes().begin(),
5858
warpOp.getResultTypes().end());
59-
auto yield = cast<gpu::YieldOp>(
60-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
59+
gpu::YieldOp yield = warpOp.getTerminator();
6160
SmallVector<Value> yieldValues(yield.getOperands().begin(),
6261
yield.getOperands().end());
6362
llvm::SmallDenseMap<Value, unsigned> indexLookup;
@@ -89,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
8988
OpOperand *WarpDistributionPattern::getWarpResult(
9089
WarpExecuteOnLane0Op warpOp,
9190
llvm::function_ref<bool(Operation *)> fn) const {
92-
auto yield = cast<gpu::YieldOp>(
93-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
91+
gpu::YieldOp yield = warpOp.getTerminator();
9492
for (OpOperand &yieldOperand : yield->getOpOperands()) {
9593
Value yieldValues = yieldOperand.get();
9694
Operation *definedOp = yieldValues.getDefiningOp();

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
528528

529529
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
530530
PatternRewriter &rewriter) const override {
531-
auto yield = cast<gpu::YieldOp>(
532-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
531+
gpu::YieldOp yield = warpOp.getTerminator();
533532
Operation *lastNode = yield->getPrevNode();
534533
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
535534
if (!writeOp)
@@ -846,8 +845,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
846845
newYieldValues.reserve(warpOp->getNumResults());
847846
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
848847
DenseMap<OpResult, int64_t> dedupResultPositionMap;
849-
auto yield = cast<gpu::YieldOp>(
850-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
848+
gpu::YieldOp yield = warpOp.getTerminator();
851849

852850
// Some values may be yielded multiple times and correspond to multiple
853851
// results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +899,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
901899
using Base::Base;
902900
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
903901
PatternRewriter &rewriter) const override {
904-
auto yield = cast<gpu::YieldOp>(
905-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
902+
gpu::YieldOp yield = warpOp.getTerminator();
906903
Value valForwarded;
907904
unsigned resultIndex;
908905
for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1705,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17081705
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17091706
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
17101707
PatternRewriter &rewriter) const override {
1711-
auto warpOpYield = cast<gpu::YieldOp>(
1712-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1708+
gpu::YieldOp warpOpYield = warpOp.getTerminator();
17131709
// Only pick up `ForOp` if it is the last op in the region.
17141710
Operation *lastNode = warpOpYield->getPrevNode();
17151711
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
336336
using gpu::WarpDistributionPattern::WarpDistributionPattern;
337337
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
338338
PatternRewriter &rewriter) const override {
339-
auto yield = cast<gpu::YieldOp>(
340-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
339+
gpu::YieldOp yield = warpOp.getTerminator();
341340
Operation *lastNode = yield->getPrevNode();
342341
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
343342
if (!storeOp)
@@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
449448
// Make sure the same load op is the last operation in the warp op body.
450449
// This ensure that load op is not sinked earlier violating any barrier
451450
// synchronizations.
452-
auto yield = cast<gpu::YieldOp>(
453-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
451+
gpu::YieldOp yield = warpOp.getTerminator();
454452
return yield->getPrevNode() == op;
455453
});
456454

@@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
752750
using gpu::WarpDistributionPattern::WarpDistributionPattern;
753751
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
754752
PatternRewriter &rewriter) const override {
755-
auto yield = cast<gpu::YieldOp>(
756-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
753+
gpu::YieldOp yield = warpOp.getTerminator();
757754
Operation *lastNode = yield->getPrevNode();
758755
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
759756
if (!prefetchOp)
@@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
794791
using gpu::WarpDistributionPattern::WarpDistributionPattern;
795792
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
796793
PatternRewriter &rewriter) const override {
797-
auto yield = cast<gpu::YieldOp>(
798-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
794+
gpu::YieldOp yield = warpOp.getTerminator();
799795
Operation *lastNode = yield->getPrevNode();
800796
// The last node must be a gpu::BarrierOp.
801797
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);

0 commit comments

Comments
 (0)