Skip to content

Commit bcc9d85

Browse files
committed
Address feedback
1 parent daa143f commit bcc9d85

File tree

3 files changed

+39
-53
lines changed

3 files changed

+39
-53
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
195195

196196
/// Helper to get the default layout for a vector type.
197197
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
198-
bool scattered = false) {
198+
bool isScattered = false) {
199199
// Expecting a 1D or 2D vector.
200200
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
201201
"Expected 1D or 2D vector.");
@@ -208,7 +208,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
208208
// Packing factor is determined by the element type bitwidth.
209209
int packingFactor = 1;
210210
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
211-
if (scattered) {
211+
if (isScattered) {
212212
packingFactor =
213213
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
214214
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -224,7 +224,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
224224

225225
/// Helper to get the default layout for a tensor descriptor type.
226226
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
227-
bool scattered = false) {
227+
bool isScattered = false) {
228228
// Expecting a 1D or 2D TensorDesc.
229229
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
230230
"Expected 1D or 2D TensorDesc.");
@@ -237,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
237237
// Packing factor is determined by the element type bitwidth.
238238
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
239239

240-
if (scattered) {
240+
if (isScattered) {
241241
int packingFactor =
242242
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
243243
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -558,7 +558,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
558558
ArrayRef<const LayoutInfoLattice *> results) {
559559
// The layout is strictly determined by the payload type.
560560
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
561-
assert(payloadTy && "Only vector payload distribution is supported");
561+
if (!payloadTy) {
562+
load.emitWarning("Not propagating, non-vector payload supplied.");
563+
return;
564+
}
562565
LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
563566

564567
// Mask operand should have 1D default layout.
@@ -569,9 +572,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
569572
propagateIfChanged(operands[0], operands[0]->meet(layout));
570573
// Propagate the new layout to the mask and optional offset operand.
571574
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
572-
if (load.getOffsets()) {
575+
if (load.getOffsets())
573576
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
574-
}
575577
}
576578

577579
/// Propagate the layout of the descriptor to the vector offset operand in
@@ -597,7 +599,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
597599
// the tensor descriptor is equal to the subgroup size. This is ensured by
598600
// the op verifier.
599601
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
600-
assert(payloadTy && "Only vector payload distribution is supported");
602+
if (!payloadTy) {
603+
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
604+
return;
605+
}
601606
auto payloadShape = payloadTy.getShape();
602607
if (payloadShape.size() > 1)
603608
assert(

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -849,18 +849,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
849849
return rewriter.notifyMatchFailure(storeScatterOp,
850850
"Expected 1D offsets and mask vector");
851851
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
852-
assert(storeVecTy.getRank() <= 2 &&
853-
"Expected at most 2D result at SG level");
854-
VectorType distStoreVecTy;
855-
if (storeVecTy.getRank() == 2)
856-
distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
857-
else // rank 1
858-
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
859-
// Assume offset and mask producers will be distributed as well.
860-
VectorType distOffsetsTy =
861-
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
862-
VectorType distMaskTy = VectorType::get(
863-
{1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
852+
if (storeVecTy.getRank() > 2)
853+
return rewriter.notifyMatchFailure(
854+
storeScatterOp, "Expected at most 2D result at SG level");
855+
864856
std::string layoutPayloadName =
865857
xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
866858
std::string layoutOffsetsName =
@@ -884,17 +876,20 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
884876
if (failed(distStoreVecByWarpOpOrFailure) ||
885877
failed(distOffsetsByWarpOpOrFailure) ||
886878
failed(distMaskByWarpOpOrFailure)) {
887-
storeScatterOp.emitWarning(
879+
return rewriter.notifyMatchFailure(
880+
storeScatterOp,
888881
"Some vector operands have no layouts, using defaults instead.");
889882
}
890-
distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
891-
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
892-
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
883+
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
884+
VectorType expectedPayloadTy = VectorType::get(
885+
{distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
893886

894887
SmallVector<size_t> newRetIndices;
895888
SmallVector<Value> operands = storeScatterOp->getOperands();
896889
SmallVector<Type> operandTypesToYield = {
897-
distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};
890+
expectedPayloadTy, operands[1].getType(),
891+
distOffsetsByWarpOpOrFailure.value(),
892+
distMaskByWarpOpOrFailure.value()};
898893

899894
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
900895
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -958,10 +953,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
958953
return rewriter.notifyMatchFailure(loadGatherOp,
959954
"Expected 1D offsets and mask vector");
960955
// Assume offset and mask producers will be distributed as well.
961-
VectorType distOffsetsTy =
962-
VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
963-
VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
964-
965956
std::string layoutOffsetsName =
966957
xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
967958
std::string layoutMaskName =
@@ -978,16 +969,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
978969
getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
979970
if (failed(distOffsetsByWarpOpOrFailure) ||
980971
failed(distMaskByWarpOpOrFailure)) {
981-
loadGatherOp.emitWarning(
972+
return rewriter.notifyMatchFailure(
973+
loadGatherOp,
982974
"Some vector operands have no layouts, using defaults instead.");
983975
}
984-
distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
985-
distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
986976

987977
SmallVector<size_t> newRetIndices;
988978
SmallVector<Value> operands = loadGatherOp->getOperands();
989-
SmallVector<Type> operandTypesToYield = {operands[0].getType(),
990-
distOffsetsTy, distMaskTy};
979+
SmallVector<Type> operandTypesToYield = {
980+
operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
981+
distMaskByWarpOpOrFailure.value()};
991982

992983
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
993984
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -998,7 +989,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
998989
const unsigned operandIdx = producedByLastLoad->getOperandNumber();
999990
VectorType loadVecTy =
1000991
cast<VectorType>(warpOp.getResult(operandIdx).getType());
1001-
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
1002992

1003993
rewriter.setInsertionPointAfter(newWarpOp);
1004994
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -330,24 +330,15 @@ gpu.module @test {
330330
gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
331331
%1 = arith.constant dense<1>: vector<16xi1>
332332
%offset = arith.constant dense<12> : vector<16xindex>
333-
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
334-
xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
335-
gpu.return
336-
}
337-
}
338-
339-
// -----
340-
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
341-
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
342-
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
343-
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
344-
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
345-
gpu.module @test {
346-
gpu.func @scatter_ops(%src: memref<256xf16>) {
347-
%1 = arith.constant dense<1>: vector<16xi1>
348-
%offset = arith.constant dense<12> : vector<16xindex>
349-
%3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
350-
xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
333+
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
334+
layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
335+
layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
336+
} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
337+
xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {
338+
layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
339+
layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
340+
layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
341+
} : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
351342
gpu.return
352343
}
353344
}

0 commit comments

Comments
 (0)