@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+    // Distributed store payload type according to the lane layout.
+    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
+    // Expected distributed payload type is always 1D.
+    VectorType expectedPayloadTy =
+        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
+                        distPayloadTyByWarpOp.getElementType());
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
+        distPayloadTyByWarpOp, operands[1].getType(),
         distOffsetsByWarpOpOrFailure.value(),
         distMaskByWarpOpOrFailure.value()};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
+    // The payload operand may need a type adjustment due to the mismatch
+    // between the warp-distributed type and the expected SIMT type.
     rewriter.setInsertionPointAfter(newWarpOp);
+    newStoreScatterOpOperands[0] = resolveDistributedTy(
+        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
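
Aside, not part of the commit: the warp op now yields the payload in its warp-distributed shape (distPayloadTyByWarpOp), while the SIMT-level scatter store wants the flattened 1D form (expectedPayloadTy). A minimal sketch of that type computation, assuming MLIR headers are available and using the invented name flattenToSIMT1D:

    #include "mlir/IR/BuiltinTypes.h"

    // Flatten a warp-distributed vector type to its 1D SIMT counterpart,
    // e.g. vector<16x1xf32> (one lane's share) -> vector<16xf32>.
    static mlir::VectorType flattenToSIMT1D(mlir::VectorType distTy) {
      return mlir::VectorType::get({distTy.getNumElements()},
                                   distTy.getElementType());
    }
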
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         distMaskByWarpOpOrFailure.value()};
 
     const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
+    VectorType distResultTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Distributed load op will always be 1D.
+    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
+                                           distResultTy.getElementType());
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,13 +1000,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        distributedVal,
+        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
     return success();
   }
 };
 
 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps.
+/// VectorReductionOps. We also insert layouts for the newly created ops.
 static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                      TypedValue<VectorType> acc,
                                      vector::CombiningKind kind,
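
Aside, not part of the commit: the load side has the same mismatch in the opposite direction, so the 1D load result is reconciled with the warp-distributed result type via resolveDistributedTy (a helper defined earlier in this file). A minimal sketch of such a resolution, assuming a vector.shape_cast is a legal bridge between the two types (castToDistributedTy is an invented name, and the real helper may use a different cast):

    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "mlir/IR/Builders.h"

    // Bridge a 1D SIMT value back to the warp-distributed type expected by
    // the surrounding IR; both types carry the same number of elements.
    static mlir::Value castToDistributedTy(mlir::Value val,
                                           mlir::VectorType distTy,
                                           mlir::OpBuilder &b) {
      if (val.getType() == distTy)
        return val; // Nothing to reconcile.
      return mlir::vector::ShapeCastOp::create(b, val.getLoc(), distTy, val);
    }
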
@@ -1014,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
+                                 xegpu::getDistributeLayoutAttr(acc));
   // For each slice of the source, extract the slice vector, do a reduction,
   // and insert the reduced value back into the result vector.
   for (int i = 0; i < nSlices; ++i) {
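
Aside, not part of the commit: the layout bookkeeping added in this function is one recurring pattern, read the layout off an existing value and pin it onto a newly created one. A minimal sketch using the same xegpu utilities the file already calls (copyLayout is an invented name; the header path is assumed):

    #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
    #include "llvm/Support/Casting.h"

    // Copy the distribute layout from a reference value to a new op result.
    static void copyLayout(mlir::Value from, mlir::Value to) {
      if (auto layout = mlir::xegpu::getDistributeLayoutAttr(from))
        mlir::xegpu::setDistributeLayoutAttr(llvm::cast<mlir::OpResult>(to),
                                             layout);
    }
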
@@ -1029,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, {1, 1});
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-    Value slice = vector::ShapeCastOp::create(
+    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
+    // Shape cast is currently handled on the xegpu side, so layouts must be
+    // retained during lowering. The shape cast output has the same layout as
+    // the accumulator; the shape cast source has the same layout as the
+    // original reduction source.
+    // TODO: other ops generated here may also need layout attributes.
+    xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
+                                   xegpu::getDistributeLayoutAttr(src));
+    xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
+                                   xegpu::getDistributeLayoutAttr(acc));
+    // Extract and reduction produce scalars, so no result layout is needed.
     Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
-    Value reduction =
-        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    Value reduction = vector::ReductionOp::create(
+        rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult =
         vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
   }
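
Aside, not part of the commit: for a concrete trace of the loop above, take src = vector<4x8xf32>, acc = vector<4xf32>, and an add reduction along dim 1; nSlices is then 4. A scalar model of what the emitted ops compute (modelRowReduction and the example shapes are invented for illustration):

    #include <array>

    static std::array<float, 4>
    modelRowReduction(const std::array<std::array<float, 8>, 4> &src,
                      const std::array<float, 4> &acc) {
      std::array<float, 4> result{}; // arith.constant dense<0.0>
      for (int i = 0; i < 4; ++i) {
        float red = acc[i];          // vector.extract %acc[i]
        for (float v : src[i])       // extract_strided_slice + shape_cast
          red += v;                  // vector.reduction <add>
        result[i] = red;             // vector.insert into the result
      }
      return result;
    }
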
@@ -1107,7 +1132,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       return failure();
     auto reductionOp =
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
-    unsigned operandNumber = yieldOperand->getOperandNumber();
+    unsigned operandIdx = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
@@ -1121,7 +1146,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           warpOp, "Only 1 reduction dimension is supported.");
     int64_t reductionDim = reductionDims[0];
     VectorType distributedResultType =
-        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     VectorType resultType = cast<VectorType>(reductionOp.getType());
     xegpu::DistributeLayoutAttr sourceLayout =
         xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1209,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
         reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
     // Replace the warp op result with the final result.
-    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
     return success();
   }
   // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
@@ -1217,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
     auto resultDistTy =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+        xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
     xegpu::DistributeLayoutAttr resultLayout =
         xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
     if (!sourceLayout || !resultLayout)
@@ -1403,11 +1428,6 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1515,10 +1535,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return laneVal;
   };
 
-  if (enableSGReductions)
-    vector::populateDistributeReduction(
-        patterns, warpReduction,
-        /*pattern benefit=*/regularPatternBenefit);
+  vector::populateDistributeReduction(
+      patterns, warpReduction,
+      /*pattern benefit=*/regularPatternBenefit);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn,