@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
875875 storeScatterOp,
876876 " Some vector operands have no layouts, using defaults instead." );
877877 }
878- VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value ();
879- VectorType expectedPayloadTy = VectorType::get (
880- {distPayloadTy.getNumElements ()}, distPayloadTy.getElementType ());
878+ // Distributed store payload type according to the lane layout.
879+ VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value ();
880+ // Expected distributed payload type is always 1D.
881+ VectorType expectedPayloadTy =
882+ VectorType::get ({distPayloadTyByWarpOp.getNumElements ()},
883+ distPayloadTyByWarpOp.getElementType ());
881884
882885 SmallVector<size_t > newRetIndices;
883886 SmallVector<Value> operands = storeScatterOp->getOperands ();
884887 SmallVector<Type> operandTypesToYield = {
885- expectedPayloadTy , operands[1 ].getType (),
888+ distPayloadTyByWarpOp , operands[1 ].getType (),
886889 distOffsetsByWarpOpOrFailure.value (),
887890 distMaskByWarpOpOrFailure.value ()};
888891
889892 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
890893 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
891894 SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector (
892895 newRetIndices, [&](size_t idx) { return newWarpOp.getResult (idx); });
893-
896+ // The payload operand may need type adjustment due to mismatch between warp
897+ // distributed type and expected SIMT type.
894898 rewriter.setInsertionPointAfter (newWarpOp);
899+ newStoreScatterOpOperands[0 ] = resolveDistributedTy (
900+ newStoreScatterOpOperands[0 ], expectedPayloadTy, rewriter);
895901 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create (
896902 rewriter, newWarpOp.getLoc (), TypeRange{}, newStoreScatterOpOperands,
897903 storeScatterOp->getAttrs ());
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
976982 distMaskByWarpOpOrFailure.value ()};
977983
978984 const unsigned operandIdx = producedByLastLoad->getOperandNumber ();
979- VectorType loadVecTy =
985+ VectorType distResultTy =
980986 cast<VectorType>(warpOp.getResult (operandIdx).getType ());
987+ // Distributed load op will always be 1D.
988+ VectorType loadVecTy = VectorType::get ({distResultTy.getNumElements ()},
989+ distResultTy.getElementType ());
981990
982991 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
983992 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,7 +1000,10 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
9911000 loadGatherOp->getAttrs ());
9921001 xegpu::removeLayoutAttrs (newOp);
9931002 Value distributedVal = newWarpOp.getResult (operandIdx);
994- rewriter.replaceAllUsesWith (distributedVal, newOp->getResult (0 ));
1003+ // Resolve the output type and replace all uses.
1004+ rewriter.replaceAllUsesWith (
1005+ distributedVal,
1006+ resolveDistributedTy (newOp.getResult (), distResultTy, rewriter));
9951007 return success ();
9961008 }
9971009};
@@ -1107,7 +1119,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11071119 return failure ();
11081120 auto reductionOp =
11091121 cast<vector::MultiDimReductionOp>(yieldOperand->get ().getDefiningOp ());
1110- unsigned operandNumber = yieldOperand->getOperandNumber ();
1122+ unsigned operandIdx = yieldOperand->getOperandNumber ();
11111123 VectorType sourceType = reductionOp.getSourceVectorType ();
11121124 // Only 2D vectors are supported.
11131125 if (sourceType.getRank () != 2 )
@@ -1121,7 +1133,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11211133 warpOp, " Only 1 reduction dimension is supported." );
11221134 int64_t reductionDim = reductionDims[0 ];
11231135 VectorType distributedResultType =
1124- cast<VectorType>(warpOp.getResult (operandNumber ).getType ());
1136+ cast<VectorType>(warpOp.getResult (operandIdx ).getType ());
11251137 VectorType resultType = cast<VectorType>(reductionOp.getType ());
11261138 xegpu::DistributeLayoutAttr sourceLayout =
11271139 xegpu::getDistributeLayoutAttr (reductionOp.getSource ());
@@ -1184,7 +1196,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
11841196 cast<TypedValue<VectorType>>(newWarpOp->getResult (newRetIndices[1 ])),
11851197 reductionOp.getKind (), reductionDim, reductionOp.getLoc (), rewriter);
11861198 // Replace the warp op result with the final result.
1187- rewriter.replaceAllUsesWith (reductionOp .getResult (), result);
1199+ rewriter.replaceAllUsesWith (newWarpOp .getResult (operandIdx ), result);
11881200 return success ();
11891201 }
11901202 // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
0 commit comments