@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
198198 case vector::CombiningKind::ADD:
199199 case vector::CombiningKind::XOR:
200200 // Initialize reduction vector to: | 0 | .. | 0 | r |
201- return rewriter.create <vector::InsertElementOp >(
201+ return rewriter.create <vector::InsertOp >(
202202 loc, r, constantZero (rewriter, loc, vtp),
203203 constantIndex (rewriter, loc, 0 ));
204204 case vector::CombiningKind::MUL:
205205 // Initialize reduction vector to: | 1 | .. | 1 | r |
206- return rewriter.create <vector::InsertElementOp >(
206+ return rewriter.create <vector::InsertOp >(
207207 loc, r, constantOne (rewriter, loc, vtp),
208208 constantIndex (rewriter, loc, 0 ));
209209 case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
628628 const VL vl;
629629};
630630
631+ static LogicalResult cleanReducChain (PatternRewriter &rewriter, Operation *op,
632+ Value inp) {
633+ if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) {
634+ if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) {
635+ if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) {
636+ rewriter.replaceOp (op, redOp.getVector ());
637+ return success ();
638+ }
639+ }
640+ }
641+ return failure ();
642+ }
643+
631644// / Reduction chain cleanup.
632645// / v = for { }
633- // / s = vsum(v) v = for { }
634- // / u = expand (s) -> for (v) { }
646+ // / s = vsum(v) v = for { }
647+ // / u = broadcast (s) -> for (v) { }
635648// / for (u) { }
636- template <typename VectorOp>
637- struct ReducChainRewriter : public OpRewritePattern <VectorOp> {
649+ struct ReducChainBroadcastRewriter : public OpRewritePattern <vector::BroadcastOp> {
638650public:
639- using OpRewritePattern<VectorOp >::OpRewritePattern;
651+ using OpRewritePattern<vector::BroadcastOp >::OpRewritePattern;
640652
641- LogicalResult matchAndRewrite (VectorOp op,
653+ LogicalResult matchAndRewrite (vector::BroadcastOp op,
642654 PatternRewriter &rewriter) const override {
643- Value inp = op.getSource ();
644- if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) {
645- if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) {
646- if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) {
647- rewriter.replaceOp (op, redOp.getVector ());
648- return success ();
649- }
650- }
651- }
652- return failure ();
655+ return cleanReducChain (rewriter, op, op.getSource ());
653656 }
654657};
655658
659+ // / Reduction chain cleanup.
660+ // / v = for { }
661+ // / s = vsum(v) v = for { }
662+ // / u = insert(s) -> for (v) { }
663+ // / for (u) { }
664+ struct ReducChainInsertRewriter : public OpRewritePattern <vector::InsertOp> {
665+ public:
666+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
667+
668+ LogicalResult matchAndRewrite (vector::InsertOp op,
669+ PatternRewriter &rewriter) const override {
670+ return cleanReducChain (rewriter, op, op.getValueToStore ());
671+ }
672+ };
656673} // namespace
657674
658675// ===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
668685 vector::populateVectorStepLoweringPatterns (patterns);
669686 patterns.add <ForOpRewriter>(patterns.getContext (), vectorLength,
670687 enableVLAVectorization, enableSIMDIndex32);
671- patterns.add <ReducChainRewriter<vector::InsertElementOp>,
672- ReducChainRewriter<vector::BroadcastOp>>( patterns.getContext ());
688+ patterns.add <ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
689+ patterns.getContext ());
673690}
0 commit comments