66 changes: 46 additions & 20 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -207,14 +208,23 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (isScattered) {
packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
LaneData({1, packingFactor}));
}
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
LaneData({1, packingFactor}));
}
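// Worked example (a sketch; the target constants are assumed to be
// subgroupSize = 16, packedSizeInBitsForDefault = 32, and
// packedSizeInBitsForGatherScatter = 32): for an f16 vector (bitwidth 16) the
// packing factor is 32 / 16 = 2, so the default case yields
// LaneLayout [1, 16] with LaneData [1, 2], while the scattered case yields
// LaneLayout [16, 1] with LaneData [1, 2]. A 32-bit element type keeps the
// packing factor at 1 in both cases.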

/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
bool isScattered = false) {
// Expecting a 1D or 2D TensorDesc.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
"Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();

if (tdescTy.isScattered()) {
if (isScattered) {
int packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Propagate the layout of the result to the tensor descriptor and mask
/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout is strictly determined by the tensor descriptor type.
LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
// The layout is strictly determined by the payload type.
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
if (!payloadTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);

// Mask operand should have 1D default layout.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);

// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
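// Example (a sketch mirroring the scatter_ops_chunksize test in
// propagate-layout.mlir): for
//   %v = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
//       : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// the payload is assigned lane_layout = [16, 1], lane_data = [1, 2]
// (scattered f16), and both the offsets and the mask producers are assigned
// the 1D default lane_layout = [16], lane_data = [1].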

/// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +590,39 @@ void LayoutInfoPropagation::visitCreateDescOp(
propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
/// Set the layout for the value, tensor descriptor, offset and mask operands in
/// the StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the payload is equal to the subgroup size. This is ensured by the op
// verifier.
ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
if (tdescShape.size() > 1)
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
if (!payloadTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
assert(
tdescShape[0] == xegpu::targetinfo::subgroupSize &&
payloadShape[0] == xegpu::targetinfo::subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");

LayoutInfo layout =
getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
LayoutInfo payloadLayout =
getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);

// Propagate the value layout.
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the tensor descriptor layout.
propagateIfChanged(operands[1], operands[1]->meet(layout));
// Use default 1D layout for mask operand.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
// Propagate the payload operand layout.
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination operand layout, if it is a tensor descriptor.
if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
if (storeScatter.getOffsets())
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}

namespace {
203 changes: 199 additions & 4 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};

/// Distribute a scattered store op. The offsets argument is required.
/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
/// The layouts are fixed and implicit: one offset/mask per lane.
/// The pattern changes the offset/mask vector shapes to a
/// single-element vector; **it is assumed that their producers will also be
/// distributed**. The payload vector also has a fixed distribution:
/// no chunk size -> vector of one element.
/// chunk size -> vector whose length equals the innermost dimension of the
/// subgroup-level payload.
/// Example 1 (no chunk size):
/// %mask = producer_op : vector<16xi1>
/// %offset = producer_op : vector<16xindex>
/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
/// memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
/// %mask = producer_op : vector<1xi1>
/// %offset = producer_op : vector<1xindex>
/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
/// memref<256xf16>, vector<1xindex>, vector<1xi1>
/// Example 2 (chunk size, same mask and offsets):
/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
/// To
/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
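/// The surrounding rewrite (a sketch for example 2, assuming a subgroup size
/// of 16; value names are illustrative) yields the store's operands out of
/// the warp op with their distributed types and re-creates the store after it:
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
///       -> (vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
///     ...
///     gpu.yield %payload, %src, %offset, %mask
///       : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
///   }
///   xegpu.store %r#0, %r#1[%r#2], %r#3 <{chunk_size = 8 : i64}>
///     : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>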
struct StoreDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
Operation *lastNode = warpOp.getTerminator()->getPrevNode();
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
return failure();
auto offsets = storeScatterOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()))
return rewriter.notifyMatchFailure(
storeScatterOp, "Store op must have a vector of offsets argument");
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
return rewriter.notifyMatchFailure(storeScatterOp,
"Expected 1D offsets and mask vector");
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
if (storeVecTy.getRank() > 2)
return rewriter.notifyMatchFailure(
storeScatterOp, "Expected at most 2D result at SG level");

std::string layoutPayloadName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
std::string layoutOffsetsName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
std::string layoutMaskName =
xegpu::getLayoutName(storeScatterOp->getOpOperand(3));

xegpu::LayoutAttr layoutPayload =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
xegpu::LayoutAttr layoutOffsets =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
xegpu::LayoutAttr layoutMask =
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
FailureOr<VectorType> distMaskByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
if (failed(distStoreVecByWarpOpOrFailure) ||
failed(distOffsetsByWarpOpOrFailure) ||
failed(distMaskByWarpOpOrFailure)) {
return rewriter.notifyMatchFailure(
storeScatterOp,
"Some vector operands have no layouts, using defaults instead.");
}
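// Note: the distributed payload is flattened to 1D below, e.g. a
// vector<16x8xf16> subgroup payload distributed as vector<1x8xf16> becomes
// the per-lane vector<8xf16> (cf. example 2 above).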
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
VectorType expectedPayloadTy = VectorType::get(
{distPayloadTy.getNumElements()}, distPayloadTy.getElementType());

SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
SmallVector<Type> operandTypesToYield = {
expectedPayloadTy, operands[1].getType(),
distOffsetsByWarpOpOrFailure.value(),
distMaskByWarpOpOrFailure.value()};

gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

rewriter.setInsertionPointAfter(newWarpOp);
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
rewriter.eraseOp(storeScatterOp);
return success();
}
};

/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The load's result (payload) vector is
/// expected to be distributed by its consumer; its distributed type is taken
/// from the corresponding warp-op result.
/// Example 1 (no chunk size):
/// %mask = producer_op : vector<16xi1>
/// %offset = producer_op : vector<16xindex>
/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
/// To
/// %mask = producer_op : vector<1xi1>
/// %offset = producer_op : vector<1xindex>
/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
/// Example 2 (chunk size, same mask and offsets):
/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
/// To
/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
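/// The rewrite (a sketch for example 1, assuming a subgroup size of 16; value
/// names are illustrative) yields the distributed offsets and mask out of the
/// warp op and re-creates the load after it, replacing the warp-op result that
/// carried the original load value:
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
///       -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
///     ...
///     gpu.yield %value, %src, %offset, %mask
///       : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
///   }
///   %0 = xegpu.load %r#1[%r#2], %r#3
///     : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
///   // ... all uses of %r#0 are replaced with %0 ...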
struct LoadDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
// Check that the yield operand was produced by the *last* scattered load
// op, to avoid sinking it across barriers (i.e. to maintain memory order).
return isa<xegpu::LoadGatherOp>(op) &&
warpOp.getTerminator()->getPrevNode() == op;
});
if (!producedByLastLoad)
return rewriter.notifyMatchFailure(
warpOp, "The last op is not xegpu::LoadGatherOp");

auto loadGatherOp =
producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
auto offsets = loadGatherOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()) ||
!isa<VectorType>(loadGatherOp.getMask().getType()))
return rewriter.notifyMatchFailure(
loadGatherOp,
"Load op must have a vector arguments for offsets and mask");
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
return rewriter.notifyMatchFailure(loadGatherOp,
"Expected 1D offsets and mask vector");
// Assume offset and mask producers will be distributed as well.
std::string layoutOffsetsName =
xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
std::string layoutMaskName =
xegpu::getLayoutName(loadGatherOp->getOpOperand(2));

xegpu::LayoutAttr layoutOffsets =
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
xegpu::LayoutAttr layoutMask =
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
FailureOr<VectorType> distMaskByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
if (failed(distOffsetsByWarpOpOrFailure) ||
failed(distMaskByWarpOpOrFailure)) {
return rewriter.notifyMatchFailure(
loadGatherOp,
"Some vector operands have no layouts, using defaults instead.");
}

SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = loadGatherOp->getOperands();
SmallVector<Type> operandTypesToYield = {
operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
distMaskByWarpOpOrFailure.value()};

const unsigned operandIdx = producedByLastLoad->getOperandNumber();
VectorType loadVecTy =
cast<VectorType>(warpOp.getResult(operandIdx).getType());

gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

rewriter.setInsertionPointAfter(newWarpOp);
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
loadGatherOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
return success();
}
};

} // namespace

namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final

void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
UpdateNdOffsetDistribution, GpuBarrierDistribution>(
patterns.getContext());
patterns
.add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
patterns.getContext());
}
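// A minimal usage sketch (illustrative only; the concrete greedy-driver call
// and the preparation steps, e.g. wrapping the function body in a
// gpu.warp_execute_on_lane_0 region and attaching layouts, are handled by the
// existing XeGPUSubgroupDistributePass):
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return signalPassFailure();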

void XeGPUSubgroupDistributePass::runOnOperation() {
34 changes: 34 additions & 0 deletions mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
return
}

// -----
// CHECK-LABEL: func.func @scatter_ops_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
return
}

// -----
// CHECK-LABEL: func.func @scatter_ops(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
%3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
return
}

// -----
// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {