Skip to content

Commit 7113245

Browse files
kumasentoivanradanov
authored andcommitted
[PlutoTransform] fix symbol alignment
1 parent 8839ea8 commit 7113245

File tree

4 files changed

+146
-8
lines changed

4 files changed

+146
-8
lines changed

tools/polymer/lib/Support/ScopStmt.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ promoteSymbolToTopLevel(mlir::Value val, FlatAffineValueConstraints &domain,
109109
symMap[val] = arg;
110110
}
111111

112+
static void reorderSymbolsByOperandId(FlatAffineValueConstraints &cst) {
113+
// bubble sort
114+
for (unsigned i = cst.getNumDimIds(); i < cst.getNumDimAndSymbolIds(); ++i)
115+
for (unsigned j = i + 1; j < cst.getNumDimAndSymbolIds(); ++j) {
116+
auto fst = cst.getValue(i).cast<BlockArgument>();
117+
auto snd = cst.getValue(j).cast<BlockArgument>();
118+
if (fst.getArgNumber() > snd.getArgNumber())
119+
cst.swapId(i, j);
120+
}
121+
}
122+
112123
void ScopStmtImpl::initializeDomainAndEnclosingOps() {
113124
// Extract the affine for/if ops enclosing the caller and insert them into the
114125
// enclosingOps list.
@@ -140,6 +151,9 @@ void ScopStmtImpl::initializeDomainAndEnclosingOps() {
140151
&symValues);
141152
for (mlir::Value val : symValues)
142153
promoteSymbolToTopLevel(val, domain, symMap);
154+
155+
// Without this things like swapped-bounds.mlir in test cannot work.
156+
reorderSymbolsByOperandId(domain);
143157
}
144158

145159
void ScopStmtImpl::getArgsValueMapping(BlockAndValueMapping &argMap) {

tools/polymer/lib/Target/OpenScop/ConvertToOpenScop.cc

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,33 @@ void OslScopBuilder::buildScopStmtMap(mlir::FuncOp f,
196196
void OslScopBuilder::buildScopContext(OslScop *scop,
197197
OslScop::ScopStmtMap *scopStmtMap,
198198
FlatAffineValueConstraints &ctx) const {
199-
ctx.reset();
199+
LLVM_DEBUG(dbgs() << "--- Building SCoP context ...\n");
200+
201+
// First initialize the symbols of the ctx by the order of arg number.
202+
// This simply aims to make mergeAndAlignIdsWithOthers work.
203+
SmallVector<Value> symbols;
204+
for (const auto &it : *scopStmtMap) {
205+
auto domain = it.second.getDomain();
206+
SmallVector<Value> syms;
207+
domain->getValues(domain->getNumDimIds(), domain->getNumDimAndSymbolIds(),
208+
&syms);
209+
210+
for (Value sym : syms) {
211+
// Find the insertion position.
212+
auto it = symbols.begin();
213+
while (it != symbols.end()) {
214+
auto lhs = it->cast<BlockArgument>();
215+
auto rhs = sym.cast<BlockArgument>();
216+
if (lhs.getArgNumber() >= rhs.getArgNumber())
217+
break;
218+
++it;
219+
}
220+
if (*it != sym)
221+
symbols.insert(it, sym);
222+
}
223+
}
224+
ctx.reset(/*numDims=*/0, /*numSymbols=*/symbols.size());
225+
ctx.setValues(0, symbols.size(), symbols);
200226

201227
// Union with the domains of all Scop statements. We first merge and align the
202228
// IDs of the context and the domain of the scop statement, and then append
@@ -210,6 +236,31 @@ void OslScopBuilder::buildScopContext(OslScop *scop,
210236
ctx.mergeAndAlignIdsWithOther(0, &cst);
211237
ctx.append(cst);
212238
ctx.removeRedundantConstraints();
239+
240+
LLVM_DEBUG(dbgs() << "Statement:\n");
241+
LLVM_DEBUG(it.second.getCaller().dump());
242+
LLVM_DEBUG(it.second.getCallee().dump());
243+
LLVM_DEBUG(dbgs() << "Target domain: \n");
244+
LLVM_DEBUG(domain->dump());
245+
246+
LLVM_DEBUG({
247+
dbgs() << "Domain values: \n";
248+
SmallVector<Value> values;
249+
domain->getValues(0, domain->getNumDimAndSymbolIds(), &values);
250+
for (Value value : values)
251+
dbgs() << " * " << value << '\n';
252+
});
253+
254+
LLVM_DEBUG(dbgs() << "Updated context: \n");
255+
LLVM_DEBUG(ctx.dump());
256+
257+
LLVM_DEBUG({
258+
dbgs() << "Context values: \n";
259+
SmallVector<Value> values;
260+
ctx.getValues(0, ctx.getNumDimAndSymbolIds(), &values);
261+
for (Value value : values)
262+
dbgs() << " * " << value << '\n';
263+
});
213264
}
214265

215266
// Then, create the single context relation in scop.
@@ -221,19 +272,43 @@ void OslScopBuilder::buildScopContext(OslScop *scop,
221272
SmallVector<mlir::Value, 8> symValues;
222273
ctx.getValues(ctx.getNumDimIds(), ctx.getNumDimAndSymbolIds(), &symValues);
223274

275+
// Add and align domain SYMBOL columns.
224276
for (const auto &it : *scopStmtMap) {
225277
FlatAffineValueConstraints *domain = it.second.getDomain();
278+
// For any symbol missing in the domain, add them directly to the end.
279+
for (unsigned i = 0; i < ctx.getNumSymbolIds(); ++i) {
280+
unsigned pos;
281+
if (!domain->findId(symValues[i], &pos)) // insert to the back
282+
domain->appendSymbolId(symValues[i]);
283+
else
284+
LLVM_DEBUG(dbgs() << "Found " << symValues[i] << '\n');
285+
}
226286

287+
// Then do the aligning.
288+
LLVM_DEBUG(domain->dump());
227289
for (unsigned i = 0; i < ctx.getNumSymbolIds(); i++) {
228290
mlir::Value sym = symValues[i];
229291
unsigned pos;
230-
if (domain->findId(sym, &pos)) {
231-
if (pos != i + domain->getNumDimIds())
232-
domain->swapId(i + domain->getNumDimIds(), pos);
233-
} else {
234-
domain->insertSymbolId(i, sym);
235-
}
292+
assert(domain->findId(sym, &pos));
293+
294+
unsigned posAsCtx = i + domain->getNumDimIds();
295+
LLVM_DEBUG(dbgs() << "Swapping " << posAsCtx << " " << pos << "\n");
296+
if (pos != posAsCtx)
297+
domain->swapId(posAsCtx, pos);
236298
}
299+
300+
// for (unsigned i = 0; i < ctx.getNumSymbolIds(); i++) {
301+
// mlir::Value sym = symValues[i];
302+
// unsigned pos;
303+
// // If the symbol can be found in the domain, we put it in the same
304+
// // position as the ctx.
305+
// if (domain->findId(sym, &pos)) {
306+
// if (pos != i + domain->getNumDimIds())
307+
// domain->swapId(i + domain->getNumDimIds(), pos);
308+
// } else {
309+
// domain->insertSymbolId(i, sym);
310+
// }
311+
// }
237312
}
238313
}
239314

tools/polymer/lib/Transforms/PlutoTransform.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ using namespace mlir;
3737
using namespace llvm;
3838
using namespace polymer;
3939

40+
#define DEBUG_TYPE "pluto-opt"
41+
4042
namespace {
4143
struct PlutoOptPipelineOptions
4244
: public mlir::PassPipelineOptions<PlutoOptPipelineOptions> {
@@ -72,6 +74,8 @@ static mlir::FuncOp plutoTransform(mlir::FuncOp f, OpBuilder &rewriter,
7274
bool parallelize = false, bool debug = false,
7375
int cloogf = -1, int cloogl = -1,
7476
bool diamondTiling = false) {
77+
LLVM_DEBUG(dbgs() << "Pluto transforming: \n");
78+
LLVM_DEBUG(f.dump());
7579

7680
PlutoContext *context = pluto_context_alloc();
7781
OslSymbolTable srcTable, dstTable;
@@ -128,6 +132,28 @@ static mlir::FuncOp plutoTransform(mlir::FuncOp f, OpBuilder &rewriter,
128132
return g;
129133
}
130134

135+
static void dedupIndexCast(FuncOp f) {
136+
Block &entry = f.getBlocks().front();
137+
llvm::MapVector<Value, Value> argToCast;
138+
SmallVector<Operation *> toErase;
139+
for (auto &op : entry) {
140+
if (auto indexCast = dyn_cast<arith::IndexCastOp>(&op)) {
141+
auto arg = indexCast.getOperand().dyn_cast<BlockArgument>();
142+
if (argToCast.count(arg)) {
143+
LLVM_DEBUG(dbgs() << "Found duplicated index_cast: " << indexCast
144+
<< '\n');
145+
indexCast.replaceAllUsesWith(argToCast.lookup(arg));
146+
toErase.push_back(indexCast);
147+
} else {
148+
argToCast[arg] = indexCast;
149+
}
150+
}
151+
}
152+
153+
for (auto op : toErase)
154+
op->erase();
155+
}
156+
131157
namespace {
132158
class PlutoTransformPass
133159
: public mlir::PassWrapper<PlutoTransformPass,
@@ -156,8 +182,10 @@ class PlutoTransformPass
156182
llvm::DenseMap<mlir::FuncOp, mlir::FuncOp> funcMap;
157183

158184
m.walk([&](mlir::FuncOp f) {
159-
if (!f->getAttr("scop.stmt") && !f->hasAttr("scop.ignored"))
185+
if (!f->getAttr("scop.stmt") && !f->hasAttr("scop.ignored")) {
186+
dedupIndexCast(f);
160187
funcOps.push_back(f);
188+
}
161189
});
162190

163191
for (mlir::FuncOp f : funcOps)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: polymer-opt %s -pluto-opt | FileCheck %s
2+
func private @S0() attributes {scop.stmt}
3+
func private @S1() attributes {scop.stmt}
4+
5+
func @foo(%N: index, %M: index, %L: index) {
6+
affine.for %i = 0 to %N {
7+
affine.for %j = 0 to %L {
8+
call @S0() : () -> ()
9+
}
10+
affine.for %j = 0 to %M {
11+
affine.for %k = 0 to %L {
12+
call @S1() : () -> ()
13+
}
14+
}
15+
}
16+
return
17+
}
18+
19+
20+
// Just need to check this thing can be transformed.
21+
// CHECK: func @foo

0 commit comments

Comments
 (0)