
Commit a4d4e66

Address feedback
1 parent 6d22968 commit a4d4e66

1 file changed: +82 -41 lines

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 82 additions & 41 deletions
@@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to a single-element vector;
+/// **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    if (!storeScatterOp.getOffsets())
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Store op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-    VectorType storeVecTy =
-        cast<VectorType>(storeScatterOp.getValue().getType());
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
     VectorType distStoreVecTy;
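Note: the hunk ends right where distStoreVecTy is computed, so the fixed payload distribution described in the new doc comment is not visible in this diff. A minimal sketch of that computation, using the names from the snippet above (an illustration, not the committed code):

// Sketch only (not part of this commit): derive the per-lane payload type
// from the SG-level payload, following the rule stated in the doc comment.
// Rank-1 payload (no chunk size) -> one element per lane.
// Rank-2 payload (chunk size)    -> keep the innermost dimension per lane.
if (storeVecTy.getRank() == 1)
  distStoreVecTy = VectorType::get({1}, storeVecTy.getElementType());
else
  distStoreVecTy = VectorType::get({storeVecTy.getShape().back()},
                                   storeVecTy.getElementType());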
@@ -837,80 +858,99 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypes[0] = distStoreVecTy;
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[3] = VectorType::get(
+    operandTypesToYield[0] = distStoreVecTy;
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[3] = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     rewriter.eraseOp(storeScatterOp);
     return success();
   }
 };
 
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
-      if (!isa<xegpu::LoadGatherOp>(op))
-        return false;
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      return yield->getPrevNode() == op;
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check that the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!yieldOperand)
+    if (!producedByLastLoad)
       return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+          warpOp, "The last op is not xegpu::LoadGatherOp");
 
     auto loadGatherOp =
-        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
-    if (!loadGatherOp.getOffsets())
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Load op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(loadGatherOp.getOffsets().getType());
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp, "Load op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[2] = VectorType::get(
-        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[1] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType loadVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
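Note: both patterns follow the same recipe: yield the gather/scatter operands out of the warp op with per-lane types, then recreate the op after the warp op and erase or replace the original. The registration site is outside this diff; a hedged sketch of what it could look like, assuming a pattern-population function of the usual MLIR form (the function name and signature are assumptions, not part of this commit):

// Sketch only: how the two distribution patterns might be added to the
// pass's RewritePatternSet. Names here are assumed for illustration.
void xegpu::populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<StoreDistribution, LoadDistribution>(patterns.getContext());
}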
@@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
         continue;
       auto layout =
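Note: both patterns above build the same per-lane type for the offset and mask vectors (one element per lane), which is the fixed, implicit layout the comment added here refers to. A hypothetical helper (not in this commit) capturing that convention would be:

// Sketch only: hypothetical helper expressing the fixed per-lane distribution
// used for offsets and masks in the load/store patterns above.
static VectorType getLaneVectorType(Type sgLevelTy) {
  return VectorType::get({1}, getElementTypeOrSelf(sgLevelTy));
}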
