@@ -132,28 +132,6 @@ static mlir::FuncOp plutoTransform(mlir::FuncOp f, OpBuilder &rewriter,
132
132
return g;
133
133
}
134
134
135
- static void dedupIndexCast (FuncOp f) {
136
- Block &entry = f.getBlocks ().front ();
137
- llvm::MapVector<Value, Value> argToCast;
138
- SmallVector<Operation *> toErase;
139
- for (auto &op : entry) {
140
- if (auto indexCast = dyn_cast<arith::IndexCastOp>(&op)) {
141
- auto arg = indexCast.getOperand ().dyn_cast <BlockArgument>();
142
- if (argToCast.count (arg)) {
143
- LLVM_DEBUG (dbgs () << " Found duplicated index_cast: " << indexCast
144
- << ' \n ' );
145
- indexCast.replaceAllUsesWith (argToCast.lookup (arg));
146
- toErase.push_back (indexCast);
147
- } else {
148
- argToCast[arg] = indexCast;
149
- }
150
- }
151
- }
152
-
153
- for (auto op : toErase)
154
- op->erase ();
155
- }
156
-
157
135
namespace {
158
136
class PlutoTransformPass
159
137
: public mlir::PassWrapper<PlutoTransformPass,
@@ -183,7 +161,6 @@ class PlutoTransformPass
183
161
184
162
m.walk ([&](mlir::FuncOp f) {
185
163
if (!f->getAttr (" scop.stmt" ) && !f->hasAttr (" scop.ignored" )) {
186
- dedupIndexCast (f);
187
164
funcOps.push_back (f);
188
165
}
189
166
});
@@ -300,10 +277,45 @@ struct PlutoParallelizePass
300
277
};
301
278
} // namespace
302
279
280
+ static void dedupIndexCast (FuncOp f) {
281
+ if (f.getBlocks ().empty ())
282
+ return ;
283
+
284
+ Block &entry = f.getBlocks ().front ();
285
+ llvm::MapVector<Value, Value> argToCast;
286
+ SmallVector<Operation *> toErase;
287
+ for (auto &op : entry) {
288
+ if (auto indexCast = dyn_cast<arith::IndexCastOp>(&op)) {
289
+ auto arg = indexCast.getOperand ().dyn_cast <BlockArgument>();
290
+ if (argToCast.count (arg)) {
291
+ LLVM_DEBUG (dbgs () << " Found duplicated index_cast: " << indexCast
292
+ << ' \n ' );
293
+ indexCast.replaceAllUsesWith (argToCast.lookup (arg));
294
+ toErase.push_back (indexCast);
295
+ } else {
296
+ argToCast[arg] = indexCast;
297
+ }
298
+ }
299
+ }
300
+
301
+ for (auto op : toErase)
302
+ op->erase ();
303
+ }
304
+
305
+ namespace {
306
+ struct DedupIndexCastPass
307
+ : public mlir::PassWrapper<DedupIndexCastPass,
308
+ OperationPass<mlir::FuncOp>> {
309
+ void runOnOperation () override { dedupIndexCast (getOperation ()); }
310
+ };
311
+ } // namespace
312
+
303
313
void polymer::registerPlutoTransformPass () {
304
314
PassPipelineRegistration<PlutoOptPipelineOptions>(
305
315
" pluto-opt" , " Optimization implemented by PLUTO." ,
306
316
[](OpPassManager &pm, const PlutoOptPipelineOptions &pipelineOptions) {
317
+ pm.addPass (std::make_unique<DedupIndexCastPass>());
318
+ pm.addPass (createCanonicalizerPass ());
307
319
pm.addPass (std::make_unique<PlutoTransformPass>(pipelineOptions));
308
320
pm.addPass (createCanonicalizerPass ());
309
321
if (pipelineOptions.generateParallel ) {
0 commit comments