11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
/// it will check the operand itself and its defining op.
LayoutAttr getLayoutAttr(const OpOperand &opr);

+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations.
+void removeLayoutAttrs(Operation *op);

/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
/// it to the owner's dictionary attributes
template <typename T,
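For readers new to these utilities, here is a minimal usage sketch. It assumes the usual MLIR headers and an `Operation *` whose operands or results may carry layout attributes; `stripFirstResultLayout` is a hypothetical helper, not part of the patch:

```cpp
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Remove the layout attribute from one result, then from every operand and
// result of the op (and of all ops nested in its regions).
void stripFirstResultLayout(Operation *op) {
  if (op->getNumResults() > 0)
    xegpu::removeLayoutAttr(op->getOpResult(0)); // single OpResult
  xegpu::removeLayoutAttrs(op);                  // recursive, whole op
}
```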
140 changes: 70 additions & 70 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -135,19 +135,6 @@ static Value resolveDistributedTy(Value orig, T expected,
return orig;
}

-/// Helper function to filter out the temporary layout attributes attached
-/// during the layout assignment process. These are not needed after going to
-/// SIMT.
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute> newAttrs;
-  for (NamedAttribute attr : attrs) {
-    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
-      newAttrs.push_back(attr);
-  }
-  return newAttrs;
-}

/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
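The body of `hasPackedLayout` is elided in this view. A minimal sketch of the check the comment describes, assuming `LayoutAttr::getLaneData()` returns a `DenseI32ArrayAttr` as elsewhere in the dialect:

```cpp
// Sketch only: "packed" means a 2-element lane_data whose first entry is
// not 1, i.e. the data is packed along the column dimension.
static bool hasPackedLayoutSketch(xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  return laneData && laneData.size() == 2 && laneData.asArrayRef()[0] != 1;
}
```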
@@ -197,9 +184,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
return isa<gpu::WarpExecuteOnLane0Op>(op);
}))
return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
// Create a WarpExecuteOnLane0Op with same arguments and results as the
// original gpuFuncOp.
rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
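`llvm::map_to_vector`, used above to collect the attribution types, maps a range through a callable and materializes the results in a `SmallVector`. A self-contained toy illustration (not part of the patch):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

int main() {
  // Map each element of [0, 5) through a lambda, collecting the results.
  auto squares = llvm::map_to_vector(llvm::seq(0, 5),
                                     [](int i) { return i * i; });
  // squares now holds {0, 1, 4, 9, 16}.
  return squares.size() == 5 ? 0 : 1;
}
```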
@@ -265,13 +260,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
if (!operand)
return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
unsigned operandIdx = operand->getOperandNumber();

@@ -288,9 +283,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
newYieldValues.push_back(operand);
newYieldTypes.push_back(operand.getType());
}
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
+        rewriter, warpOp, /* new yielded values = */ newYieldValues,
/* new yielded types = */ newYieldTypes, newRetIndices);

SmallVector<Value> newDescOperands;
@@ -347,10 +342,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -372,7 +367,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
/* new yielded values = */
ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
/* new yielded types = */
@@ -403,9 +398,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
distributedTensorDescTy, rewriter));

-    rewriter.create<xegpu::StoreNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
-        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+    auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+    xegpu::removeLayoutAttrs(newStoreOp);
rewriter.eraseOp(storeOp);
return success();
}
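This is the PR's central refactor, repeated in each pattern below: instead of filtering layout attributes out of the attribute list before creating the new op, copy all attributes and strip the layouts from the new op in place. A condensed sketch of the before/after idiom; `SomeXeGPUOp`, `loc`, `resultTypes`, `operands`, and `oldOp` are placeholders, not names from the patch:

```cpp
// Before: build a filtered attribute list, then create the op with it.
rewriter.create<SomeXeGPUOp>(
    loc, resultTypes, operands,
    removeTemporaryLayoutAttributes(oldOp->getAttrs()));

// After: copy every attribute, then remove layout attributes on the new op.
auto newOp = rewriter.create<SomeXeGPUOp>(loc, resultTypes, operands,
                                          oldOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
```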
@@ -449,21 +444,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensures that the load op is not sunk earlier, which would
+      // violate barrier synchronization.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });

     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensures that the load op is not sunk earlier, which would violate
-    // barrier synchronization.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");

auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +470,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {

unsigned operandIdx = operand->getOperandNumber();
VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
/* new yielded values = */ loadOp.getTensorDesc(),
/* new yielded types = */ tensorDescTy, newRetIndices);

Expand All @@ -498,7 +494,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
distributedTensorDescTy, rewriter),
-        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+        loadOp->getAttrs());
+    xegpu::removeLayoutAttrs(newLoadOp);
// Set the packed attribute if the layout requires it.
newLoadOp.setPacked(hasPackedLayout(layout));
Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -548,12 +545,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
"warp result is not a xegpu::Dpas op");

auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +595,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
// Create a new warp op without the dpas.
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

FailureOr<VectorType> expectedDistLhsTyOrFailure =
xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -630,14 +626,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
newDpasOperandExpectedTypes[i], rewriter));
}
-    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
-        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
-        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+    auto newDpasOp =
+        rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+                                       newDpasOperands, dpasOp->getAttrs());
+    xegpu::removeLayoutAttrs(newDpasOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type.
-    newDpasOp = resolveDistributedTy(
-        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+    Value typeResolved =
+        resolveDistributedTy(newDpasOp.getResult(),
+                             distResultTypeByWarpOpOrFailure.value(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
return success();
}
};
@@ -678,13 +676,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
if (!operand)
return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
unsigned operandIdx = operand->getOperandNumber();
    // The new update op does not have a layout attribute.
@@ -703,7 +701,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
}
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newUpdateOperands;
for (size_t i : newRetIndices) {
@@ -717,14 +715,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
}
}
// Create a new update op outside the warp op.
-    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+        updateOp->getAttrs());
+    xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
-    newUpdateOp =
-        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    Value typeResolved = resolveDistributedTy(
+        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
return success();
}
};
@@ -758,10 +757,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -775,17 +774,18 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with the updated tensor
     // descriptor type. The source tensor descriptor requires type resolution.
     xegpu::TensorDescType newTensorDescTy =
         prefetchOp.getTensorDescType().dropLayouts();
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
-    rewriter.create<xegpu::PrefetchNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
-        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    auto newPrefetchOp = rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(newPrefetchOp);
rewriter.eraseOp(prefetchOp);
return success();
}
Expand All @@ -795,17 +795,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
/// region. This will simply move the barrier op outside of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
if (!barrierOp)
return failure();
// Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
rewriter.create<gpu::BarrierOp>(
barrierOp.getLoc(), barrierOp->getResultTypes(),
barrierOp->getOperands(), barrierOp->getAttrs());
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -184,6 +184,33 @@ void xegpu::setLayoutAttrs(Operation *op,
});
}

+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<LayoutAttr>(name))
+    owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands()) {
+      removeLayoutAttr(opr);
+    }
Review comment (Contributor): nit: single statement braces.
Reply (Author): fixed.

+    for (OpResult result : nestOp->getOpResults()) {
+      removeLayoutAttr(result);
+    }
+  });
+}
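Per the review exchange above, the braces around the single-statement loop bodies were dropped in a later commit; the function as landed would read approximately:

```cpp
void xegpu::removeLayoutAttrs(Operation *op) {
  op->walk([&](Operation *nestOp) {
    // Strip layout attributes from every operand and every result.
    for (OpOperand &opr : nestOp->getOpOperands())
      removeLayoutAttr(opr);
    for (OpResult result : nestOp->getOpResults())
      removeLayoutAttr(result);
  });
}
```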
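The explicit instantiations above are needed because the template is declared in the header but defined in this .cpp file; without them, other translation units could not link against `removeLayoutAttr`. A two-file illustration of the same pattern with hypothetical names (`frob`, `util.h`):

```cpp
// util.h -- declaration only; the defaulted SFINAE parameter restricts T.
#include <type_traits>
template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, int> ||
                                      std::is_same_v<T, long>>>
void frob(const T &value);

// util.cpp -- definition plus explicit instantiations. The second template
// parameter is named but not re-defaulted here.
template <typename T, typename>
void frob(const T &value) { /* ... */ }

template void frob<int>(const int &);   // emits frob<int> for linking
template void frob<long>(const long &); // emits frob<long> for linking
```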

SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
Value value, ArrayRef<int64_t> shape) {