Skip to content

Commit daa143f

Browse files
committed
Add layout-based distribution
1 parent a4d4e66 commit daa143f

File tree

4 files changed

+182
-58
lines changed

4 files changed

+182
-58
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
194194
}
195195

196196
/// Helper to get the default layout for a vector type.
197-
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
197+
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
198+
bool scattered = false) {
198199
// Expecting a 1D or 2D vector.
199200
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
200201
"Expected 1D or 2D vector.");
@@ -207,14 +208,23 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
207208
// Packing factor is determined by the element type bitwidth.
208209
int packingFactor = 1;
209210
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
211+
if (scattered) {
212+
packingFactor =
213+
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
214+
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
215+
: 1;
216+
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
217+
LaneData({1, packingFactor}));
218+
}
210219
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
211220
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
212221
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
213222
LaneData({1, packingFactor}));
214223
}
215224

216225
/// Helper to get the default layout for a tensor descriptor type.
217-
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
226+
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
227+
bool scattered = false) {
218228
// Expecting a 1D or 2D TensorDesc.
219229
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
220230
"Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
227237
// Packing factor is determined by the element type bitwidth.
228238
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
229239

230-
if (tdescTy.isScattered()) {
240+
if (scattered) {
231241
int packingFactor =
232242
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
233243
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,27 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
541551
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
542552
}
543553

544-
/// Propagate the layout of the result to the tensor descriptor and mask
554+
/// Propagate the layout of the result to the tensor descriptor, mask and offset
545555
/// operands in LoadGatherOp.
546556
void LayoutInfoPropagation::visitLoadGatherOp(
547557
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
548558
ArrayRef<const LayoutInfoLattice *> results) {
549-
// The layout is strictly determined by the tensor descriptor type.
550-
LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
559+
// The layout is strictly determined by the payload type.
560+
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
561+
assert(payloadTy && "Only vector payload distribution is supported");
562+
LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
551563

552564
// Mask operand should have 1D default layout.
553565
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
554566

555567
// Propagate the new layout to the tensor descriptor operand.
556-
propagateIfChanged(operands[0], operands[0]->meet(layout));
557-
// Propagate the new layout to the mask operand.
568+
if (isa<xegpu::TensorDescType>(load.getSourceType()))
569+
propagateIfChanged(operands[0], operands[0]->meet(layout));
570+
// Propagate the new layout to the mask and optional offset operand.
558571
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
572+
if (load.getOffsets()) {
573+
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
574+
}
559575
}
560576

561577
/// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +588,36 @@ void LayoutInfoPropagation::visitCreateDescOp(
572588
propagateIfChanged(operands[1], operands[1]->meet(layout));
573589
}
574590

575-
/// Set the layout for the value, tensor descriptor, and mask operands in the
576-
/// StoreScatterOp.
591+
/// Set the layout for the value, tensor descriptor, offset and mask operands in
592+
/// the StoreScatterOp.
577593
void LayoutInfoPropagation::visitStoreScatterOp(
578594
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
579595
ArrayRef<const LayoutInfoLattice *> results) {
580596
// Currently, for 2D StoreScatterOp we expect that the height dimension of
581597
// the tensor descriptor is equal to the subgroup size. This is ensured by
582598
// the op verifier.
583-
ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
584-
if (tdescShape.size() > 1)
599+
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
600+
assert(payloadTy && "Only vector payload distribution is supported");
601+
auto payloadShape = payloadTy.getShape();
602+
if (payloadShape.size() > 1)
585603
assert(
586-
tdescShape[0] == xegpu::targetinfo::subgroupSize &&
604+
payloadShape[0] == xegpu::targetinfo::subgroupSize &&
587605
"Expected the first dimension of 2D tensor descriptor to be equal to "
588606
"subgroup size.");
589607

590-
LayoutInfo layout =
591-
getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
608+
LayoutInfo payloadLayout =
609+
getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
592610

593-
// Propagate the value layout.
594-
propagateIfChanged(operands[0], operands[0]->meet(layout));
595-
// Propagate the tensor descriptor layout.
596-
propagateIfChanged(operands[1], operands[1]->meet(layout));
597-
// Use default 1D layout for mask operand.
598611
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
612+
// Propagate the payload operand layout
613+
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
614+
// Propagate the destination (if tdesc) operand layout
615+
if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
616+
propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
617+
// Propagate the new layout to the mask and optional offset operand.
599618
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
619+
if (storeScatter.getOffsets())
620+
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
600621
}
601622

602623
namespace {

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -844,9 +844,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
844844
return rewriter.notifyMatchFailure(
845845
storeScatterOp, "Store op must have a vector of offsets argument");
846846
VectorType offsetsTy = cast<VectorType>(offsets.getType());
847-
if (offsetsTy.getRank() != 1)
847+
VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
848+
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
848849
return rewriter.notifyMatchFailure(storeScatterOp,
849-
"Expected 1D offsets vector");
850+
"Expected 1D offsets and mask vector");
850851
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
851852
assert(storeVecTy.getRank() <= 2 &&
852853
"Expected at most 2D result at SG level");
@@ -855,17 +856,45 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
855856
distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
856857
else // rank 1
857858
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
858-
859-
SmallVector<size_t> newRetIndices;
860-
SmallVector<Value> operands = storeScatterOp->getOperands();
861-
SmallVector<Type> operandTypesToYield =
862-
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
863-
operandTypesToYield[0] = distStoreVecTy;
864859
// Assume offset and mask producers will be distributed as well.
865-
operandTypesToYield[2] =
860+
VectorType distOffsetsTy =
866861
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
867-
operandTypesToYield[3] = VectorType::get(
862+
VectorType distMaskTy = VectorType::get(
868863
{1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
864+
std::string layoutPayloadName =
865+
xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
866+
std::string layoutOffsetsName =
867+
xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
868+
std::string layoutMaskName =
869+
xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
870+
871+
xegpu::LayoutAttr layoutPayload =
872+
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
873+
xegpu::LayoutAttr layoutOffsets =
874+
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
875+
xegpu::LayoutAttr layoutMask =
876+
storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
877+
878+
FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
879+
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
880+
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
881+
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
882+
FailureOr<VectorType> distMaskByWarpOpOrFailure =
883+
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
884+
if (failed(distStoreVecByWarpOpOrFailure) ||
885+
failed(distOffsetsByWarpOpOrFailure) ||
886+
failed(distMaskByWarpOpOrFailure)) {
887+
storeScatterOp.emitWarning(
888+
"Some vector operands have no layouts, using defaults instead.");
889+
}
890+
distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
891+
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
892+
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
893+
894+
SmallVector<size_t> newRetIndices;
895+
SmallVector<Value> operands = storeScatterOp->getOperands();
896+
SmallVector<Type> operandTypesToYield = {
897+
distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};
869898

870899
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
871900
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -918,23 +947,47 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
918947
auto loadGatherOp =
919948
producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
920949
auto offsets = loadGatherOp.getOffsets();
921-
if (!offsets || !isa<VectorType>(offsets.getType()))
950+
if (!offsets || !isa<VectorType>(offsets.getType()) ||
951+
!isa<VectorType>(loadGatherOp.getMask().getType()))
922952
return rewriter.notifyMatchFailure(
923-
loadGatherOp, "Load op must have a vector of offsets argument");
953+
loadGatherOp,
954+
"Load op must have a vector arguments for offsets and mask");
924955
VectorType offsetsTy = cast<VectorType>(offsets.getType());
925-
if (offsetsTy.getRank() != 1)
956+
VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
957+
if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
926958
return rewriter.notifyMatchFailure(loadGatherOp,
927-
"Expected 1D offsets vector");
959+
"Expected 1D offsets and mask vector");
960+
// Assume offset and mask producers will be distributed as well.
961+
VectorType distOffsetsTy =
962+
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
963+
VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
964+
965+
std::string layoutOffsetsName =
966+
xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
967+
std::string layoutMaskName =
968+
xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
969+
970+
xegpu::LayoutAttr layoutOffsets =
971+
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
972+
xegpu::LayoutAttr layoutMask =
973+
loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
974+
975+
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
976+
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
977+
FailureOr<VectorType> distMaskByWarpOpOrFailure =
978+
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
979+
if (failed(distOffsetsByWarpOpOrFailure) ||
980+
failed(distMaskByWarpOpOrFailure)) {
981+
loadGatherOp.emitWarning(
982+
"Some vector operands have no layouts, using defaults instead.");
983+
}
984+
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
985+
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
928986

929987
SmallVector<size_t> newRetIndices;
930988
SmallVector<Value> operands = loadGatherOp->getOperands();
931-
SmallVector<Type> operandTypesToYield =
932-
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
933-
// Assume offset and mask producers will be distributed as well.
934-
operandTypesToYield[1] =
935-
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
936-
operandTypesToYield[2] =
937-
VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
989+
SmallVector<Type> operandTypesToYield = {operands[0].getType(),
990+
distOffsetsTy, distMaskTy};
938991

939992
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
940993
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -951,6 +1004,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
9511004
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
9521005
newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
9531006
loadGatherOp->getAttrs());
1007+
xegpu::removeLayoutAttrs(newOp);
9541008
Value distributedVal = newWarpOp.getResult(operandIdx);
9551009
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
9561010
return success();
@@ -990,7 +1044,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
9901044

9911045
// Vector operands of these ops have a fixed and implicit layout.
9921046
if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
993-
continue;
1047+
continue;
9941048
auto layout =
9951049
xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
9961050
if (!layout) {

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
162162
return
163163
}
164164

165+
// -----
166+
// CHECK-LABEL: func.func @scatter_ops_chunksize(
167+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
168+
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
169+
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
170+
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
171+
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
172+
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
173+
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
174+
%1 = arith.constant dense<1>: vector<16xi1>
175+
%offset = arith.constant dense<12> : vector<16xindex>
176+
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
177+
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
178+
xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
179+
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
180+
return
181+
}
182+
183+
// -----
184+
// CHECK-LABEL: func.func @scatter_ops(
185+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
186+
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
187+
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
188+
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
189+
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
190+
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
191+
func.func @scatter_ops(%src: memref<256xf16>) {
192+
%1 = arith.constant dense<1>: vector<16xi1>
193+
%offset = arith.constant dense<12> : vector<16xindex>
194+
%3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
195+
xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
196+
return
197+
}
198+
165199
// -----
166200
// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
167201
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {

0 commit comments

Comments
 (0)