@@ -1027,22 +1027,28 @@ struct WgToSgVectorShapeCastOp
10271027 }
10281028};
10291029
1030- // Pattern for lowering vector.multi_reduction op to subgroup level.
1030+ /// Pattern for lowering vector.multi_reduction op to subgroup level.
1031+ /// Current limitation: only supports 2D->1D reduction with a single reduction
1032+ /// dimension, and the sg_layout in the reduced dimension must be 1
1033+ /// so that the reduction is local to the subgroup and no cross-subgroup
1034+ /// communication is needed.
1035+ /// TODO: Add cases to handle more general situations which require SLM access.
10311036struct WgToSgMultiDimReductionOp
10321037 : public OpConversionPattern<vector::MultiDimReductionOp> {
10331038 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
10341039
10351040 LogicalResult
10361041 matchAndRewrite (vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
10371042 ConversionPatternRewriter &rewriter) const override {
1038- VectorType srcType = dyn_cast<VectorType>( op.getSource (). getType () );
1043+ VectorType srcType = op.getSourceVectorType ( );
10391044 VectorType dstType = dyn_cast<VectorType>(op.getResult ().getType ());
1040- if (!srcType || ! dstType)
1045+ if (!dstType)
10411046 return failure ();
10421047
1043- // TODO: generalize it
1044- auto srcShape = srcType.getShape ();
1045- auto dstShape = dstType.getShape ();
1048+ SmallVector<int64_t > srcShape (srcType.getShape ().begin (),
1049+ srcType.getShape ().end ());
1050+ SmallVector<int64_t > dstShape (dstType.getShape ().begin (),
1051+ dstType.getShape ().end ());
10461052 if (srcShape.size () != 2 || dstShape.size () != 1 )
10471053 return failure ();
10481054
@@ -1051,7 +1057,8 @@ struct WgToSgMultiDimReductionOp
10511057 if (!layout || !layout.isForWorkgroup ())
10521058 return failure ();
10531059
1054- auto reductionDims = op.getReductionDims ();
1060+ SmallVector<int64_t > reductionDims (op.getReductionDims ().begin (),
1061+ op.getReductionDims ().end ());
10551062 if (reductionDims.size () != 1 )
10561063 return failure ();
10571064
@@ -1069,8 +1076,8 @@ struct WgToSgMultiDimReductionOp
10691076 SmallVector<Value> newReductions;
10701077 for (auto [sgSrc, sgAcc] :
10711078 llvm::zip (adaptor.getSource (), adaptor.getAcc ())) {
1072- auto newOp = rewriter. create < vector::MultiDimReductionOp> (
1073- op.getLoc (), newDstType, op.getKind (), sgSrc, sgAcc,
1079+ auto newOp = vector::MultiDimReductionOp::create (
1080+ rewriter, op.getLoc (), newDstType, op.getKind (), sgSrc, sgAcc,
10741081 op.getReductionDims ());
10751082 if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
10761083 !layout.getEffectiveInstDataAsInt ().empty ())
0 commit comments