Skip to content

Commit c0240fa

Browse files
committed
Added polymorphic scope, removed unnecessary scopes, refactored code
1 parent 59a3b4c commit c0240fa

File tree

1 file changed

+51
-34
lines changed

1 file changed

+51
-34
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -204,21 +204,41 @@ struct CppEmitter {
204204
/// Whether to map an mlir integer to a unsigned integer in C++.
205205
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
206206

207-
/// RAII helper function to manage entering/exiting C++ scopes.
207+
/// Abstract RAII helper function to manage entering/exiting C++ scopes.
208208
struct Scope {
209-
Scope(CppEmitter &emitter)
210-
: valueMapperScope(emitter.valueMapper),
211-
blockMapperScope(emitter.blockMapper), emitter(emitter) {
212-
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
213-
}
214209
~Scope() { emitter.labelInScopeCount.pop(); }
215210

216211
private:
217212
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
218213
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
214+
215+
protected:
216+
Scope(CppEmitter &emitter)
217+
: valueMapperScope(emitter.valueMapper),
218+
blockMapperScope(emitter.blockMapper), emitter(emitter) {
219+
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
220+
}
219221
CppEmitter &emitter;
220222
};
221223

224+
/// RAII helper function to manage entering/exiting functions, while re-using
225+
/// value names.
226+
struct FunctionScope : Scope {
227+
FunctionScope(CppEmitter &emitter) : Scope(emitter) {
228+
// Re-use value names
229+
emitter.resetValueCounter();
230+
}
231+
};
232+
233+
/// RAII helper function to manage entering/exiting emitc::forOp loops and
234+
/// handle induction variable naming.
235+
struct LoopScope : Scope {
236+
LoopScope(CppEmitter &emitter) : Scope(emitter) {
237+
emitter.increaseLoopNestingLevel();
238+
}
239+
~LoopScope() { emitter.decreaseLoopNestingLevel(); }
240+
};
241+
222242
/// Returns wether the Value is assigned to a C++ variable in the scope.
223243
bool hasValueInScope(Value val);
224244

@@ -264,12 +284,12 @@ struct CppEmitter {
264284
/// This emitter will only emit translation units whos id matches this value.
265285
StringRef willOnlyEmitTu() { return onlyTu; }
266286

267-
/// Reduces stacks and updates value counter
268-
void popStacksAndUpdate();
269-
270287
// Resets the value counter to 0
271288
void resetValueCounter();
272289

290+
// Increases the loop nesting level by 1
291+
void increaseLoopNestingLevel();
292+
273293
// Decreases the loop nesting level by 1
274294
void decreaseLoopNestingLevel();
275295

@@ -297,11 +317,17 @@ struct CppEmitter {
297317
/// Map from block to name of C++ label.
298318
BlockMapper blockMapper;
299319

320+
/// Default values representing outermost scope
321+
llvm::ScopedHashTableScope<Value, std::string> defaultValueMapperScope;
322+
llvm::ScopedHashTableScope<Block *, std::string> defaultBlockMapperScope;
323+
300324
std::stack<int64_t> labelInScopeCount;
301325

326+
/// Keeps track of the amount of nested loops the emitter currently operates
327+
/// in.
302328
uint64_t loopNestingLevel{0};
303329

304-
/// Emitter-level count of created values to enable unique identifiers
330+
/// Emitter-level count of created values to enable unique identifiers.
305331
unsigned int valueCount{0};
306332

307333
/// State of the current expression being emitted.
@@ -922,7 +948,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
922948
}
923949

924950
static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
925-
CppEmitter::Scope scope(emitter);
926951
raw_indented_ostream &os = emitter.ostream();
927952

928953
// Utility function to determine whether a value is an expression that will be
@@ -964,6 +989,8 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
964989
os << ") {\n";
965990
os.indent();
966991

992+
CppEmitter::LoopScope lScope(emitter);
993+
967994
Region &forRegion = forOp.getRegion();
968995
auto regionOps = forRegion.getOps();
969996

@@ -975,8 +1002,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
9751002

9761003
os.unindent() << "}";
9771004

978-
emitter.decreaseLoopNestingLevel();
979-
9801005
return success();
9811006
}
9821007

@@ -1052,9 +1077,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
10521077
}
10531078

10541079
static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
1055-
emitter.resetValueCounter();
1056-
CppEmitter::Scope scope(emitter);
1057-
10581080
for (Operation &op : moduleOp) {
10591081
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
10601082
return failure();
@@ -1066,9 +1088,6 @@ static LogicalResult printOperation(CppEmitter &emitter, TranslationUnitOp tu) {
10661088
if (!emitter.shouldEmitTu(tu))
10671089
return success();
10681090

1069-
emitter.resetValueCounter();
1070-
CppEmitter::Scope scope(emitter);
1071-
10721091
for (Operation &op : tu) {
10731092
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
10741093
return failure();
@@ -1231,8 +1250,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12311250
return functionOp.emitOpError() << "cannot emit array type as result type";
12321251
}
12331252

1234-
emitter.resetValueCounter();
1235-
CppEmitter::Scope scope(emitter);
1253+
CppEmitter::FunctionScope scope(emitter);
12361254
raw_indented_ostream &os = emitter.ostream();
12371255
if (failed(emitter.emitTypes(functionOp.getLoc(),
12381256
functionOp.getFunctionType().getResults())))
@@ -1260,8 +1278,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12601278
"with multiple blocks needs variables declared at top");
12611279
}
12621280

1263-
emitter.resetValueCounter();
1264-
CppEmitter::Scope scope(emitter);
1281+
CppEmitter::FunctionScope scope(emitter);
12651282
raw_indented_ostream &os = emitter.ostream();
12661283
if (functionOp.getSpecifiers()) {
12671284
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
@@ -1295,8 +1312,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
12951312

12961313
static LogicalResult printOperation(CppEmitter &emitter,
12971314
DeclareFuncOp declareFuncOp) {
1298-
emitter.resetValueCounter();
1299-
CppEmitter::Scope scope(emitter);
13001315
raw_indented_ostream &os = emitter.ostream();
13011316

13021317
auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
@@ -1328,7 +1343,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
13281343
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
13291344
StringRef onlyTu, bool constantsAsVariables)
13301345
: os(os), declareVariablesAtTop(declareVariablesAtTop),
1331-
onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables) {
1346+
onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables),
1347+
defaultValueMapperScope(valueMapper),
1348+
defaultBlockMapperScope(blockMapper) {
13321349
labelInScopeCount.push(0);
13331350
}
13341351

@@ -1359,9 +1376,8 @@ std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
13591376
}
13601377

13611378
void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1362-
if (!valueMapper.count(value)) {
1379+
if (!valueMapper.count(value))
13631380
valueMapper.insert(value, str.str());
1364-
}
13651381
}
13661382

13671383
/// Return the existing or a new name for a Value.
@@ -1383,16 +1399,15 @@ StringRef CppEmitter::getOrCreateName(emitc::ForOp forOp) {
13831399

13841400
if (!valueMapper.count(val)) {
13851401

1386-
int64_t identifier = loopNestingLevel++;
1402+
int64_t identifier = 'i' + loopNestingLevel;
13871403

1388-
char range = 'z' - 'i';
1389-
if (identifier >= 0 && identifier <= range) {
1390-
valueMapper.insert(
1391-
val, formatv("{0}_{1}", (char)(identifier + 'i'), ++valueCount));
1404+
if (identifier >= 'i' && identifier <= 'z') {
1405+
valueMapper.insert(val,
1406+
formatv("{0}_{1}", (char)identifier, ++valueCount));
13921407
} else {
13931408
// If running out of letters, continue with zX
13941409
valueMapper.insert(
1395-
val, formatv("z{0}_{1}", identifier - range - 1, ++valueCount));
1410+
val, formatv("z{0}_{1}", identifier - 'z' - 1, ++valueCount));
13961411
}
13971412
}
13981413
return *valueMapper.begin(val);
@@ -1989,6 +2004,8 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
19892004

19902005
void CppEmitter::resetValueCounter() { valueCount = 0; }
19912006

2007+
void CppEmitter::increaseLoopNestingLevel() { loopNestingLevel++; }
2008+
19922009
void CppEmitter::decreaseLoopNestingLevel() { loopNestingLevel--; }
19932010

19942011
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,

0 commit comments

Comments
 (0)