@@ -74,19 +74,25 @@ class VectorFusionBase {
7474 func::FuncOp func;
7575 // / Type helper class, can help us to get operation type
7676 TypeHelper typehelper;
77+ // / IR rewriter
78+ IRRewriter *rewriter;
7779
7880public:
79- VectorFusionBase () = default ;
80- VectorFusionBase (func::FuncOp & func, HardWareInfo & info)
81- : func(func), typehelper(info) {}
82- VectorFusionBase (VectorFusionBase & base)
83- : func(base.getFunction()), typehelper(base.getHardwareInfo() ) {}
81+ VectorFusionBase (func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter)
82+ : func(func), typehelper( info), rewriter(rewriter) {}
83+ VectorFusionBase (VectorFusionBase &base, IRRewriter *rewriter)
84+ : func(base.getFunction()), typehelper( base.getHardwareInfo()),
85+ rewriter (rewriter ) {}
8486
8587 // / get current function IR
8688 func::FuncOp &getFunction () { return func; }
8789 // / get current hardware info
88- HardWareInfo &getHardwareInfo () { return typehelper.getHardwareInfo (); }
89- TypeHelper &getTypeHelper () { return typehelper; }
90+ HardWareInfo &getHardwareInfo () noexcept {
91+ return typehelper.getHardwareInfo ();
92+ }
93+ TypeHelper &getTypeHelper () noexcept { return typehelper; }
94+ IRRewriter *getRewriter () noexcept { return rewriter; }
95+ void setRewriter (IRRewriter *rewriter) noexcept { this ->rewriter = rewriter; }
9096};
9197
9298// / Group operation fusion strategy class.
@@ -132,17 +138,20 @@ class GroupOperationFusion : public VectorFusionBase {
132138 DenseMap<Value, Value> operandOriginalValue;
133139
134140public:
135- GroupOperationFusion (func::FuncOp &func, HardWareInfo &info)
136- : VectorFusionBase(func, info) {}
141+ GroupOperationFusion (func::FuncOp &func, HardWareInfo &info,
142+ IRRewriter *rewriter)
143+ : VectorFusionBase(func, info, rewriter) {}
137144
138- GroupOperationFusion (GroupOperationFusion &strategy)
139- : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
145+ GroupOperationFusion (GroupOperationFusion &strategy, IRRewriter *rewriter)
146+ : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
147+ rewriter),
140148 opGroups (strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps),
141149 opGroupIndexMap(strategy.opGroupIndexMap),
142150 opAnchorPos(strategy.opAnchorPos){};
143151
144- GroupOperationFusion (GroupOperationFusion &&strategy)
145- : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()),
152+ GroupOperationFusion (GroupOperationFusion &&strategy, IRRewriter *rewriter)
153+ : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(),
154+ rewriter),
146155 opGroups(std::move(strategy.opGroups)),
147156 groupMaxSteps(std::move(strategy.groupMaxSteps)),
148157 groupBigestRankVectorType(
@@ -165,9 +174,9 @@ class GroupOperationFusion : public VectorFusionBase {
165174 this ->getFunction () = fusion.getFunction ();
166175 this ->getHardwareInfo () = fusion.getHardwareInfo ();
167176 this ->getTypeHelper () = fusion.getTypeHelper ();
177+ this ->setRewriter (fusion.getRewriter ());
168178 return *this ;
169179 };
170- GroupOperationFusion &operator =(GroupOperationFusion &&) = default ;
171180
172181 // / Get the map which contains each group vector type which has biggest
173182 // / rank.
@@ -275,10 +284,12 @@ class GroupOperationAnalysis {
275284private:
276285 // / vector-based fusion related data
277286 GroupOperationFusion fusionStrategy;
287+ IRRewriter *rewriter;
278288
279289public:
280- GroupOperationAnalysis (func::FuncOp &func, HardWareInfo &info)
281- : fusionStrategy(func, info) {}
290+ GroupOperationAnalysis (func::FuncOp &func, HardWareInfo &info,
291+ IRRewriter *rewriter)
292+ : fusionStrategy(func, info, rewriter), rewriter(rewriter) {}
282293 // / remove the useless operation, due to it result is not require by other
283294 // / operation
284295 void analysisEmptyGroup ();
@@ -288,6 +299,8 @@ class GroupOperationAnalysis {
288299 GroupOperationFusion &getGroupOperationFusion () { return fusionStrategy; }
289300 // / running the vector-based fusion
290301 void run () { fusionStrategy.run (); }
302+ // / get current function rewriter
303+ IRRewriter *getRewriter () { return rewriter; }
291304};
292305} // namespace gc
293306} // namespace mlir
0 commit comments