Skip to content

Commit a4494ea

Browse files
authored
Fix bug where emitc constants wouldn't be directly emitted in subscripts. (#411)
* Fix bug where emitc constants wouldn't be directly emitted in subscripts. * Use ConstantOp as a deferred op when constantsAsVariables=false
1 parent 63401e3 commit a4494ea

File tree

2 files changed

+90
-34
lines changed

2 files changed

+90
-34
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,16 @@ struct CppEmitter {
254254
return operandExpression == emittedExpression;
255255
};
256256

257+
/// Determine whether expression \p expressionOp should be emitted inline,
258+
/// i.e. as part of its user. This function recommends inlining of any
259+
/// expressions that can be inlined unless it is used by another expression,
260+
/// under the assumption that any expression fusion/re-materialization was
261+
/// taken care of by transformations run by the backend.
262+
bool shouldBeInlined(ExpressionOp expressionOp);
263+
264+
/// This emitter will only emit translation units whos id matches this value.
265+
StringRef willOnlyEmitTu() { return onlyTu; }
266+
257267
private:
258268
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
259269
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -297,21 +307,22 @@ struct CppEmitter {
297307
return lowestPrecedence();
298308
return emittedExpressionPrecedence.back();
299309
}
310+
311+
/// Determine whether expression \p op should be emitted in a deferred way.
312+
bool hasDeferredEmission(Operation *op);
300313
};
301314
} // namespace
302315

303-
/// Determine whether expression \p op should be emitted in a deferred way.
304-
static bool hasDeferredEmission(Operation *op) {
316+
bool CppEmitter::hasDeferredEmission(Operation *op) {
317+
if (llvm::isa_and_nonnull<emitc::ConstantOp>(op)) {
318+
return !shouldUseConstantsAsVariables();
319+
}
320+
305321
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
306322
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
307323
}
308324

309-
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
310-
/// as part of its user. This function recommends inlining of any expressions
311-
/// that can be inlined unless it is used by another expression, under the
312-
/// assumption that any expression fusion/re-materialization was taken care of
313-
/// by transformations run by the backend.
314-
static bool shouldBeInlined(ExpressionOp expressionOp) {
325+
bool CppEmitter::shouldBeInlined(ExpressionOp expressionOp) {
315326
// Do not inline if expression is marked as such.
316327
if (expressionOp.getDoNotInline())
317328
return false;
@@ -373,6 +384,25 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
373384
static LogicalResult printOperation(CppEmitter &emitter,
374385
emitc::ConstantOp constantOp) {
375386
if (!emitter.shouldUseConstantsAsVariables()) {
387+
std::string out;
388+
llvm::raw_string_ostream ss(out);
389+
390+
/// Temporary emitter object that writes to our stream instead of the output
391+
/// allowing for the capture and caching of the produced string.
392+
CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(),
393+
emitter.willOnlyEmitTu(),
394+
emitter.shouldUseConstantsAsVariables());
395+
396+
ss << "(";
397+
if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType())))
398+
return failure();
399+
ss << ") ";
400+
401+
if (failed(
402+
sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue())))
403+
return failure();
404+
405+
emitter.cacheDeferredOpResult(constantOp.getResult(), out);
376406
return success();
377407
}
378408

@@ -838,7 +868,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
838868

839869
static LogicalResult printOperation(CppEmitter &emitter,
840870
emitc::ExpressionOp expressionOp) {
841-
if (shouldBeInlined(expressionOp))
871+
if (emitter.shouldBeInlined(expressionOp))
842872
return success();
843873

844874
Operation &op = *expressionOp.getOperation();
@@ -892,7 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
892922
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
893923
if (!expressionOp)
894924
return false;
895-
return shouldBeInlined(expressionOp);
925+
return emitter.shouldBeInlined(expressionOp);
896926
};
897927

898928
os << "for (";
@@ -1114,7 +1144,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
11141144
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
11151145
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
11161146
(isa<emitc::ExpressionOp>(op) &&
1117-
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1147+
emitter.shouldBeInlined(cast<emitc::ExpressionOp>(op))))
11181148
return WalkResult::skip();
11191149
for (OpResult result : op->getResults()) {
11201150
if (failed(emitter.emitVariableDeclaration(
@@ -1494,22 +1524,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
14941524

14951525
LogicalResult CppEmitter::emitOperand(Value value) {
14961526
Operation *def = value.getDefiningOp();
1497-
if (!shouldUseConstantsAsVariables()) {
1498-
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
1499-
os << "((";
1500-
1501-
if (failed(emitType(constant.getLoc(), constant.getType()))) {
1502-
return failure();
1503-
}
1504-
os << ") ";
1505-
1506-
if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
1507-
return failure();
1508-
}
1509-
os << ")";
1510-
return success();
1511-
}
1512-
}
15131527

15141528
if (isPartOfCurrentExpression(value)) {
15151529
assert(def && "Expected operand to be defined by an operation");
@@ -1721,11 +1735,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17211735
cacheDeferredOpResult(op.getResult(), op.getValue());
17221736
return success();
17231737
})
1724-
.Case<emitc::MemberOp>([&](auto op) {
1725-
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1726-
return success();
1727-
})
1728-
.Case<emitc::MemberOfPtrOp>([&](auto op) {
1738+
.Case<emitc::MemberOp, emitc::MemberOfPtrOp>([&](auto op) {
17291739
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
17301740
return success();
17311741
})

mlir/test/Target/Cpp/emitc-constants-as-variables.mlir

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,55 @@ func.func @test() {
1111

1212
return
1313
}
14+
// CPP-DEFAULT-LABEL: void test() {
15+
// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) {
16+
// CPP-DEFAULT-NEXT: }
17+
// CPP-DEFAULT-NEXT: return;
18+
// CPP-DEFAULT-NEXT: }
19+
20+
// -----
21+
22+
func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue<f32>) {
23+
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
24+
%0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
25+
return %0 : !emitc.lvalue<f32>
26+
}
27+
// CPP-DEFAULT-LABEL: float test_subscript(float v1[4]) {
28+
// CPP-DEFAULT-NEXT: return v1[(size_t) 0];
29+
// CPP-DEFAULT-NEXT: }
30+
31+
// -----
1432

15-
// CPP-DEFAULT: void test() {
16-
// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) {
33+
func.func @emitc_switch_ui64() {
34+
%0 = "emitc.constant"(){value = 1 : ui64} : () -> ui64
35+
36+
emitc.switch %0 : ui64
37+
default {
38+
emitc.call_opaque "func2" (%0) : (ui64) -> ()
39+
emitc.yield
40+
}
41+
return
42+
}
43+
// CPP-DEFAULT-LABEL: void emitc_switch_ui64() {
44+
// CPP-DEFAULT: switch ((uint64_t) 1) {
45+
// CPP-DEFAULT-NEXT: default: {
46+
// CPP-DEFAULT-NEXT: func2((uint64_t) 1);
47+
// CPP-DEFAULT-NEXT: break;
1748
// CPP-DEFAULT-NEXT: }
1849
// CPP-DEFAULT-NEXT: return;
1950
// CPP-DEFAULT-NEXT: }
51+
52+
// -----
53+
54+
func.func @negative_values() {
55+
%1 = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t
56+
%2 = "emitc.constant"() <{value = -3000000000 : index}> : () -> !emitc.ssize_t
57+
58+
%3 = emitc.add %1, %2 : (!emitc.size_t, !emitc.ssize_t) -> !emitc.ssize_t
59+
60+
return
61+
}
62+
// CPP-DEFAULT-LABEL: void negative_values() {
63+
// CPP-DEFAULT-NEXT: ssize_t v1 = (size_t) 10 + (ssize_t) -3000000000;
64+
// CPP-DEFAULT-NEXT: return;
65+
// CPP-DEFAULT-NEXT: }

0 commit comments

Comments
 (0)