Merged
Changes from 4 commits
61 changes: 41 additions & 20 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
bool scattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -207,14 +208,23 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (scattered) {
packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
LaneData({1, packingFactor}));
}
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
LaneData({1, packingFactor}));
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
bool scattered = false) {
// Expecting a 1D or 2D vector.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
"Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();

if (tdescTy.isScattered()) {
if (scattered) {
int packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,27 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Propagate the layout of the result to the tensor descriptor and mask
/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout is strictly determined by the tensor descriptor type.
LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
// The layout is strictly determined by the payload type.
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
assert(payloadTy && "Only vector payload distribution is supported");
LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);

// Mask operand should have 1D default layout.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);

// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets()) {
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
}

/// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +588,36 @@ void LayoutInfoPropagation::visitCreateDescOp(
propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
/// Set the layout for the value, tensor descriptor, offset and mask operands in
/// the StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
if (tdescShape.size() > 1)
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
assert(payloadTy && "Only vector payload distribution is supported");
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
assert(
tdescShape[0] == xegpu::targetinfo::subgroupSize &&
payloadShape[0] == xegpu::targetinfo::subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");

LayoutInfo layout =
getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
LayoutInfo payloadLayout =
getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);

// Propagate the value layout.
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the tensor descriptor layout.
propagateIfChanged(operands[1], operands[1]->meet(layout));
// Use default 1D layout for mask operand.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
if (storeScatter.getOffsets())
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}

namespace {
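Note on the new defaults: below is a minimal sketch of the layout this propagation assigns to a scattered load, assuming subgroupSize = 16 and packedSizeInBitsForGatherScatter = 32 (so an f16 payload gets packingFactor = 2). The SSA names are illustrative and the attribute spelling follows xegpu::getLayoutName, so treat this as an assumption rather than verbatim pass output:

  // Illustrative IR after layout propagation: lane_layout = [16, 1] assigns one
  // row of the 16x8 payload to each lane; lane_data = [1, 2] packs two f16
  // elements per 32-bit unit along that row.
  %v = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
         {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
         : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>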
216 changes: 212 additions & 4 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,210 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};

/// Distribute a scattered store op. The offsets argument is required.
/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
/// The layouts are fixed and implicit: one offset/mask per lane.
/// The pass changes the offset/mask vector shapes to a
/// single-element vector, **it is assumed that their producer will also be
/// distributed**. The payload vector also has a fixed distribution:
/// no chunk size -> vector of one element.
/// chunk size -> vector of the innermost dimension of the SG-payload.
/// Example 1 (no chunk size):
/// %mask = producer_op : vector<16xi1>
/// %offset = producer_op : vector<16xindex>
/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
/// memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
/// %mask = producer_op : vector<1xi1>
/// %offset = producer_op : vector<1xindex>
/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
/// memref<256xf16>, vector<1xindex>, vector<1xi1>
/// Example 2 (chunk size, same mask and offsets):
/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
struct StoreDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
Operation *lastNode = warpOp.getTerminator()->getPrevNode();
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
return failure();
auto offsets = storeScatterOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()))
return rewriter.notifyMatchFailure(
storeScatterOp, "Store op must have a vector of offsets argument");
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
return rewriter.notifyMatchFailure(storeScatterOp,
"Expected 1D offsets and mask vector");
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
assert(storeVecTy.getRank() <= 2 &&
"Expected at most 2D result at SG level");
VectorType distStoreVecTy;
Contributor:
Strongly suggest using getDistVecTypeBasedOnLaneLayout to get the distributed type. This will also check that the SG vector type is consistent with the lane layout (i.e. it is distributable to lanes), which is not done here.

Contributor Author:
The mask/offset/payload vectors have predetermined rules for their shape and lane assignment. Their distribution is fixed at all times, so user-side layouts are redundant.

Contributor:
I understand that the layout is not useful here. But it is better to keep this logic in a single place. This also ensures that the layout assigned to offsets (by the propagation logic) is indeed correct.

if (storeVecTy.getRank() == 2)
distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
else // rank 1
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
// Assume offset and mask producers will be distributed as well.
VectorType distOffsetsTy =
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
VectorType distMaskTy = VectorType::get(
{1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
std::string layoutPayloadName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
std::string layoutOffsetsName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
std::string layoutMaskName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(3));

xegpu::LayoutAttr layoutPayload =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
xegpu::LayoutAttr layoutOffsets =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
xegpu::LayoutAttr layoutMask =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
FailureOr<VectorType> distMaskByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
if (failed(distStoreVecByWarpOpOrFailure) ||
failed(distOffsetsByWarpOpOrFailure) ||
failed(distMaskByWarpOpOrFailure)) {
storeScatterOp.emitWarning(
"Some vector operands have no layouts, using defaults instead.");
}
distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);

SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
SmallVector<Type> operandTypesToYield = {
distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};

gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

rewriter.setInsertionPointAfter(newWarpOp);
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
rewriter.eraseOp(storeScatterOp);
return success();
}
};

/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
/// Example 1 (no chunk size):
/// %mask = producer_op : vector<16xi1>
/// %offset = producer_op : vector<16xindex>
/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
/// To
/// %mask = producer_op : vector<1xi1>
/// %offset = producer_op : vector<1xindex>
/// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
/// Example 2 (chunk size, same mask and offsets):
/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
/// To
/// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
struct LoadDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
// Check that the yield operand was produced by the *last* scattered
// load op, to avoid sinking it before barriers (maintain memory order).
return isa<xegpu::LoadGatherOp>(op) &&
warpOp.getTerminator()->getPrevNode() == op;
});
if (!producedByLastLoad)
return rewriter.notifyMatchFailure(
warpOp, "The last op is not xegpu::LoadGatherOp");

auto loadGatherOp =
producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
auto offsets = loadGatherOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()) ||
!isa<VectorType>(loadGatherOp.getMask().getType()))
return rewriter.notifyMatchFailure(
loadGatherOp,
"Load op must have a vector arguments for offsets and mask");
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
return rewriter.notifyMatchFailure(loadGatherOp,
"Expected 1D offsets and mask vector");
// Assume offset and mask producers will be distributed as well.
VectorType distOffsetsTy =
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));

std::string layoutOffsetsName =
xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
std::string layoutMaskName =
xegpu::getLayoutName(loadGatherOp->getOpOperand(2));

xegpu::LayoutAttr layoutOffsets =
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
xegpu::LayoutAttr layoutMask =
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
FailureOr<VectorType> distMaskByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
if (failed(distOffsetsByWarpOpOrFailure) ||
failed(distMaskByWarpOpOrFailure)) {
loadGatherOp.emitWarning(
"Some vector operands have no layouts, using defaults instead.");
}
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);

SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = loadGatherOp->getOperands();
SmallVector<Type> operandTypesToYield = {operands[0].getType(),
distOffsetsTy, distMaskTy};

gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

const unsigned operandIdx = producedByLastLoad->getOperandNumber();
VectorType loadVecTy =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");

rewriter.setInsertionPointAfter(newWarpOp);
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
loadGatherOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
return success();
}
};

} // namespace

namespace {
@@ -819,10 +1023,11 @@ struct XeGPUSubgroupDistributePass final

void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
UpdateNdOffsetDistribution, GpuBarrierDistribution>(
patterns.getContext());
patterns
.add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +1042,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;

// Vectors operands of these ops have a fixed and implicit layout.
if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
continue;
auto layout =
xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
if (!layout) {
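For context, here is a hedged sketch of how LoadDistribution rewrites a warp region. The types follow Example 2 from the pattern's comment, while the gpu.warp_execute_on_lane_0 wrapper, the SSA names, and the warp size of 16 are assumptions for illustration:

  // Before distribution: the gather load sits inside the warp op and yields the
  // full subgroup-level payload; the warp op result carries the per-lane type.
  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8xf16>) {
    %v = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
           : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
    gpu.yield %v : vector<16x8xf16>
  }
  // After distribution: the warp op instead yields the source plus per-lane
  // offsets (vector<1xindex>) and mask (vector<1xi1>), and the load is recreated
  // after the warp op so that each lane loads its own vector<8xf16>.

StoreDistribution applies the same shape rules to the payload, offsets, and mask, but since the store has no results it is simply sunk below the warp op.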