Skip to content

Commit 88ed30b

Browse files
committed
save work and bug fixes
1 parent 869cfca commit 88ed30b

File tree

3 files changed

+716
-607
lines changed

3 files changed

+716
-607
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
875875
storeScatterOp,
876876
"Some vector operands have no layouts, using defaults instead.");
877877
}
878-
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
879-
VectorType expectedPayloadTy = VectorType::get(
880-
{distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
878+
// Distributed store payload type according to the lane layout.
879+
VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
880+
// Expected distributed payload type is always 1D.
881+
VectorType expectedPayloadTy =
882+
VectorType::get({distPayloadTyByWarpOp.getNumElements()},
883+
distPayloadTyByWarpOp.getElementType());
881884

882885
SmallVector<size_t> newRetIndices;
883886
SmallVector<Value> operands = storeScatterOp->getOperands();
884887
SmallVector<Type> operandTypesToYield = {
885-
expectedPayloadTy, operands[1].getType(),
888+
distPayloadTyByWarpOp, operands[1].getType(),
886889
distOffsetsByWarpOpOrFailure.value(),
887890
distMaskByWarpOpOrFailure.value()};
888891

889892
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
890893
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
891894
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
892895
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
893-
896+
// The payload operand may need type adjustment due to mismatch between warp
897+
// distributed type and expected SIMT type.
894898
rewriter.setInsertionPointAfter(newWarpOp);
899+
newStoreScatterOpOperands[0] = resolveDistributedTy(
900+
newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
895901
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
896902
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
897903
storeScatterOp->getAttrs());
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
976982
distMaskByWarpOpOrFailure.value()};
977983

978984
const unsigned operandIdx = producedByLastLoad->getOperandNumber();
979-
VectorType loadVecTy =
985+
VectorType distResultTy =
980986
cast<VectorType>(warpOp.getResult(operandIdx).getType());
987+
// Distributed load op will always be 1D.
988+
VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
989+
distResultTy.getElementType());
981990

982991
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
983992
rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,7 +1000,10 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
9911000
loadGatherOp->getAttrs());
9921001
xegpu::removeLayoutAttrs(newOp);
9931002
Value distributedVal = newWarpOp.getResult(operandIdx);
994-
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
1003+
// Resolve the output type and replace all uses.
1004+
rewriter.replaceAllUsesWith(
1005+
distributedVal,
1006+
resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
9951007
return success();
9961008
}
9971009
};
@@ -1107,7 +1119,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11071119
return failure();
11081120
auto reductionOp =
11091121
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1110-
unsigned operandNumber = yieldOperand->getOperandNumber();
1122+
unsigned operandIdx = yieldOperand->getOperandNumber();
11111123
VectorType sourceType = reductionOp.getSourceVectorType();
11121124
// Only 2D vectors are supported.
11131125
if (sourceType.getRank() != 2)
@@ -1121,7 +1133,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11211133
warpOp, "Only 1 reduction dimension is supported.");
11221134
int64_t reductionDim = reductionDims[0];
11231135
VectorType distributedResultType =
1124-
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1136+
cast<VectorType>(warpOp.getResult(operandIdx).getType());
11251137
VectorType resultType = cast<VectorType>(reductionOp.getType());
11261138
xegpu::DistributeLayoutAttr sourceLayout =
11271139
xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1196,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11841196
cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
11851197
reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
11861198
// Replace the warp op result with the final result.
1187-
rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1199+
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
11881200
return success();
11891201
}
11901202
// For non-lane-local case, we simply rewrite the MultiReductionOp in terms

0 commit comments

Comments
 (0)