@@ -1009,7 +1009,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
10091009};
10101010
10111011// / Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1012- // / VectorReductionOps.
1012+ // / VectorReductionOps. We also insert layouts for the newly created ops.
10131013static Value lowerToVectorReductions (TypedValue<VectorType> src,
10141014 TypedValue<VectorType> acc,
10151015 vector::CombiningKind kind,
@@ -1026,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10261026 Value reductionResult = arith::ConstantOp::create (
10271027 rewriter, loc, acc.getType (),
10281028 DenseElementsAttr::get (acc.getType (), zeroAttr));
1029+ // Reduction result should have the same layout as the accumulator.
1030+ xegpu::setDistributeLayoutAttr (cast<OpResult>(reductionResult),
1031+ xegpu::getDistributeLayoutAttr (acc));
10291032 // For each slice of the source, extract the slice vector, do a reduction
10301033 // and, insert the reduced value back to the result vector.
10311034 for (int i = 0 ; i < nSlices; ++i) {
@@ -1041,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
10411044 vector::ExtractStridedSliceOp::create (rewriter, loc, src, sliceOffsets,
10421045 sliceSizes, {1 , 1 });
10431046 int64_t nSliceElements = extractOp.getResult ().getType ().getNumElements ();
1044- Value slice = vector::ShapeCastOp::create (
1047+ vector::ShapeCastOp slice = vector::ShapeCastOp::create (
10451048 rewriter, loc,
10461049 VectorType::get ({nSliceElements}, sourceType.getElementType ()),
10471050 extractOp.getResult ());
1051+ // Shape cast is currently handled on the xegpu side, so layouts must be
1052+ // retained during lowering. Shape cast output has the same layout as the
1053+ // accumulator. Shape cast source has the same layout as the original
1054+ // reduction source.
1055+ // TODO: other ops generated here may also need layout attributes.
1056+ xegpu::setDistributeLayoutAttr (slice->getOpOperand (0 ),
1057+ xegpu::getDistributeLayoutAttr (src));
1058+ xegpu::setDistributeLayoutAttr (slice->getOpResult (0 ),
1059+ xegpu::getDistributeLayoutAttr (acc));
1060+ // Extract and reduction ops produce scalars, so no result layout is needed.
10481061 Value accExtract = vector::ExtractOp::create (rewriter, loc, acc, i);
1049- Value reduction =
1050- vector::ReductionOp::create ( rewriter, loc, kind, slice, accExtract);
1062+ Value reduction = vector::ReductionOp::create (
1063+ rewriter, loc, kind, slice. getResult () , accExtract);
10511064 reductionResult =
10521065 vector::InsertOp::create (rewriter, loc, reduction, reductionResult, i);
10531066 }
@@ -1229,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
12291242 auto resultDistTy =
12301243 cast<VectorType>(warpOp.getResult (operandNumber).getType ());
12311244 xegpu::DistributeLayoutAttr sourceLayout =
1232- xegpu::getDistributeLayoutAttr (shapeCastOp. getSource ( ));
1245+ xegpu::getDistributeLayoutAttr (shapeCastOp-> getOpOperand ( 0 ));
12331246 xegpu::DistributeLayoutAttr resultLayout =
12341247 xegpu::getDistributeLayoutAttr (shapeCastOp.getResult ());
12351248 if (!sourceLayout || !resultLayout)
0 commit comments