@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
   }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+    // Distributed store payload type according to the lane layout.
+    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
+    // Expected distributed payload type is always 1D.
+    VectorType expectedPayloadTy =
+        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
+                        distPayloadTyByWarpOp.getElementType());
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
+        distPayloadTyByWarpOp, operands[1].getType(),
         distOffsetsByWarpOpOrFailure.value(),
         distMaskByWarpOpOrFailure.value()};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
+    // The payload operand may need a type adjustment due to a mismatch
+    // between the warp-distributed type and the expected SIMT type.
     rewriter.setInsertionPointAfter(newWarpOp);
+    newStoreScatterOpOperands[0] = resolveDistributedTy(
+        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
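
The hunk above keeps the store payload in its warp-distributed (lane-layout) shape inside the warp op and only flattens it to the expected 1D SIMT shape right before creating the store. Below is a minimal standalone sketch of that type relationship; the concrete shapes (vector<2x1xf16> to vector<2xf16>) and the standalone setup are illustrative assumptions, not taken from the patch.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  // Assume the warp op yields the payload as vector<2x1xf16> per lane layout.
  mlir::VectorType byWarpTy =
      mlir::VectorType::get({2, 1}, mlir::Float16Type::get(&ctx));
  // The SIMT-level store expects the flattened 1D payload: vector<2xf16>.
  mlir::VectorType simtTy = mlir::VectorType::get(
      {byWarpTy.getNumElements()}, byWarpTy.getElementType());
  // resolveDistributedTy bridges exactly this kind of shape-only mismatch.
  return simtTy.getRank() == 1 ? 0 : 1;
}
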
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         distMaskByWarpOpOrFailure.value()};
 
     const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
+    VectorType distResultTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Distributed load op will always be 1D.
+    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
+                                           distResultTy.getElementType());
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,13 +1000,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        distributedVal,
+        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
     return success();
   }
 };
 
 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps.
+/// VectorReductionOps. We also insert layouts for the newly created ops.
 static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                      TypedValue<VectorType> acc,
                                      vector::CombiningKind kind,
@@ -1014,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
+                                 xegpu::getDistributeLayoutAttr(acc));
   // For each slice of the source, extract the slice vector, do a reduction
   // and insert the reduced value back to the result vector.
   for (int i = 0; i < nSlices; ++i) {
@@ -1029,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, {1, 1});
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-    Value slice = vector::ShapeCastOp::create(
+    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
+    // Shape cast is currently handled on the xegpu side, so layouts must be
+    // retained during lowering. The shape cast output has the same layout as
+    // the accumulator, and the shape cast source has the same layout as the
+    // original reduction source.
+    // TODO: other ops generated here may also need layout attributes.
+    xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
+                                   xegpu::getDistributeLayoutAttr(src));
+    xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
+                                   xegpu::getDistributeLayoutAttr(acc));
+    // Extract and reduction produce scalars, so no result layout is needed.
     Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
-    Value reduction =
-        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    Value reduction = vector::ReductionOp::create(
+        rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult =
         vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
   }
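
A minimal standalone sketch of the slice decomposition performed by the loop above; the 4x8 source shape, f32 element type, and reduction along dim 1 are illustrative assumptions, not taken from the patch.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t srcShape[2] = {4, 8}; // hypothetical 2D reduction source
  const int reductionDim = 1;         // reduce along columns
  const int64_t nSlices = srcShape[1 - reductionDim];    // one slice per row
  const int64_t nSliceElements = srcShape[reductionDim]; // 8 elements each
  for (int64_t i = 0; i < nSlices; ++i) {
    // Each slice is extracted at offsets {i, 0} with sizes {1, 8}, flattened
    // by the shape cast to a 1D vector, then reduced against acc[i].
    std::cout << "slice " << i << ": vector<" << nSliceElements << "xf32>\n";
  }
  return 0;
}
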
@@ -1107,7 +1132,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       return failure();
     auto reductionOp =
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
-    unsigned operandNumber = yieldOperand->getOperandNumber();
+    unsigned operandIdx = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
@@ -1121,7 +1146,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           warpOp, "Only 1 reduction dimension is supported.");
     int64_t reductionDim = reductionDims[0];
     VectorType distributedResultType =
-        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     VectorType resultType = cast<VectorType>(reductionOp.getType());
     xegpu::DistributeLayoutAttr sourceLayout =
         xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1209,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
           reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
       // Replace the warp op result with the final result.
-      rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
       return success();
     }
     // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
@@ -1217,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
     auto resultDistTy =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+        xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
     xegpu::DistributeLayoutAttr resultLayout =
         xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
     if (!sourceLayout || !resultLayout)
@@ -1403,11 +1428,6 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1515,10 +1535,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return laneVal;
   };
 
-  if (enableSGReductions)
-    vector::populateDistributeReduction(
-        patterns, warpReduction,
-        /*pattern benefit=*/regularPatternBenefit);
+  vector::populateDistributeReduction(
+      patterns, warpReduction,
+      /*pattern benefit=*/regularPatternBenefit);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn,