
Commit fc37818

[mlir][xegpu] Minor fixes in XeGPU subgroup distribution. (#147846)
This PR addresses the following issues:

1. Add the missing attributes when creating a new GPU funcOp in the `MoveFuncBodyToWarpExecuteOnLane0` pattern.
2. Fix a bug in LoadNd distribution to make sure the LoadOp is the last op in the warpOp region before it is distributed (needed to preserve memory-op ordering during distribution).
3. Add a utility for removing OpOperand or OpResult layout attributes.
1 parent 72a2d82
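As context for items 1 and 3 above, the distribution patterns no longer pre-filter the attribute list before creating a replacement op; they now copy every original attribute and then strip only the layout hints. A schematic fragment of that idiom, lifted from the StoreNd hunk below (`rewriter`, `newWarpOp`, `newStoreOperands`, and `storeOp` come from that pattern; this is an illustration, not extra code added by the commit):

// Old: filter layout attributes out of the list before creating the op.
//   rewriter.create<xegpu::StoreNdOp>(
//       newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
//       removeTemporaryLayoutAttributes(storeOp->getAttrs()));
// New: keep all original attributes, then remove only the layout ones.
auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
    newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
xegpu::removeLayoutAttrs(newStoreOp);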

4 files changed: +110 -74 lines

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 11 additions & 0 deletions
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
 /// it will check the operand itself and its defining op.
 LayoutAttr getLayoutAttr(const OpOperand &opr);
 
+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations
+void removeLayoutAttrs(Operation *op);
+
 /// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
 /// it to the owner's dictionary attributes
 template <typename T,
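For reference, a minimal usage sketch of the two declarations added above; the helper name `stripLayouts` and the use of slot index 0 are hypothetical, not part of this commit:

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

// Hypothetical helper: clear layout hints once SIMT distribution is done.
static void stripLayouts(mlir::Operation *op) {
  // Targeted form: remove the layout attribute of one specific slot.
  if (op->getNumOperands() > 0)
    mlir::xegpu::removeLayoutAttr(op->getOpOperand(0));
  if (op->getNumResults() > 0)
    mlir::xegpu::removeLayoutAttr(op->getOpResult(0));
  // Recursive form: walk `op` and all nested ops, clearing every layout attribute.
  mlir::xegpu::removeLayoutAttrs(op);
}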

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 70 additions & 70 deletions
@@ -135,19 +135,6 @@ static Value resolveDistributedTy(Value orig, T expected,
   return orig;
 }
 
-/// Helper function to filter out the temporary layout attributes attached
-/// during the layout assignment process. These are not needed after going to
-/// SIMT.
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute> newAttrs;
-  for (NamedAttribute attr : attrs) {
-    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
-      newAttrs.push_back(attr);
-  }
-  return newAttrs;
-}
-
 /// Helper function to check if the layout is packed. Layout is packed if it is
 /// 2D and lane_data[0] != 1 (data packed from col dimension).
 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -197,9 +184,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
     rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +260,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
 
@@ -288,9 +283,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+        rewriter, warpOp, /* new yieled values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +342,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +367,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -403,9 +398,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
 
-    rewriter.create<xegpu::StoreNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
-        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+    auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+    xegpu::removeLayoutAttrs(newStoreOp);
     rewriter.eraseOp(storeOp);
     return success();
   }
@@ -449,21 +444,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensure that load op is not sinked earlier violating any barrier
+      // synchronizations.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensure that load op is not sinked earlier violating any barrier
-    // synchronizations.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +470,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
 
@@ -498,7 +494,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                              distributedTensorDescTy, rewriter),
-        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+        loadOp->getAttrs());
+    xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(hasPackedLayout(layout));
     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -548,12 +545,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +595,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -630,14 +626,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
           resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                                newDpasOperandExpectedTypes[i], rewriter));
     }
-    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
-        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
-        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+    auto newDpasOp =
+        rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+                                       newDpasOperands, dpasOp->getAttrs());
+    xegpu::removeLayoutAttrs(newDpasOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type.
-    newDpasOp = resolveDistributedTy(
-        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+    Value typeResolved =
+        resolveDistributedTy(newDpasOp.getResult(),
+                             distResultTypeByWarpOpOrFailure.value(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -678,13 +676,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +701,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -717,14 +715,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+        updateOp->getAttrs());
+    xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
-    newUpdateOp =
-        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    Value typeResolved = resolveDistributedTy(
+        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -758,10 +757,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -775,17 +774,18 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with updated tensor
     // descriptor type. Source tensor descriptor require type resolution.
     xegpu::TensorDescType newTensorDescTy =
         prefetchOp.getTensorDescType().dropLayouts();
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
-    rewriter.create<xegpu::PrefetchNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
-        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.create<xegpu::PrefetchNdOp>(newWarpOp.getLoc(), TypeRange{},
+                                         newPrefetchOperands,
+                                         prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(prefetchOp);
     rewriter.eraseOp(prefetchOp);
     return success();
   }
@@ -795,17 +795,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
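Note on the LoadNd change above: the pattern now only matches a load that is also the last operation before the warp op's terminator, so distribution cannot sink the load past a barrier and reorder memory operations. The same check written as a standalone helper for illustration; the name `isLastOpInWarpBody` is hypothetical, while the body mirrors the lambda in the hunk:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

// Returns true if `op` immediately precedes the gpu.yield terminator of the
// warp op's body region.
static bool isLastOpInWarpBody(mlir::gpu::WarpExecuteOnLane0Op warpOp,
                               mlir::Operation *op) {
  auto yield = llvm::cast<mlir::gpu::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  return yield->getPrevNode() == op;
}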

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 25 additions & 0 deletions
@@ -184,6 +184,31 @@ void xegpu::setLayoutAttrs(Operation *op,
   });
 }
 
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<LayoutAttr>(name))
+    owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands())
+      removeLayoutAttr(opr);
+    for (OpResult result : nestOp->getOpResults())
+      removeLayoutAttr(result);
+  });
+}
+
 SmallVector<Value>
 xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                         Value value, ArrayRef<int64_t> shape) {
