diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a2e368e502737..e09fe1d1126be 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -254,6 +254,16 @@ struct CppEmitter { return operandExpression == emittedExpression; }; + /// Determine whether expression \p expressionOp should be emitted inline, + /// i.e. as part of its user. This function recommends inlining of any + /// expressions that can be inlined unless it is used by another expression, + /// under the assumption that any expression fusion/re-materialization was + /// taken care of by transformations run by the backend. + bool shouldBeInlined(ExpressionOp expressionOp); + + /// This emitter will only emit translation units whos id matches this value. + StringRef willOnlyEmitTu() { return onlyTu; } + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -297,21 +307,22 @@ struct CppEmitter { return lowestPrecedence(); return emittedExpressionPrecedence.back(); } + + /// Determine whether expression \p op should be emitted in a deferred way. + bool hasDeferredEmission(Operation *op); }; } // namespace -/// Determine whether expression \p op should be emitted in a deferred way. -static bool hasDeferredEmission(Operation *op) { +bool CppEmitter::hasDeferredEmission(Operation *op) { + if (llvm::isa_and_nonnull(op)) { + return !shouldUseConstantsAsVariables(); + } + return isa_and_nonnull(op); } -/// Determine whether expression \p expressionOp should be emitted inline, i.e. -/// as part of its user. This function recommends inlining of any expressions -/// that can be inlined unless it is used by another expression, under the -/// assumption that any expression fusion/re-materialization was taken care of -/// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +bool CppEmitter::shouldBeInlined(ExpressionOp expressionOp) { // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -373,6 +384,25 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp) { if (!emitter.shouldUseConstantsAsVariables()) { + std::string out; + llvm::raw_string_ostream ss(out); + + /// Temporary emitter object that writes to our stream instead of the output + /// allowing for the capture and caching of the produced string. + CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(), + emitter.willOnlyEmitTu(), + emitter.shouldUseConstantsAsVariables()); + + ss << "("; + if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType()))) + return failure(); + ss << ") "; + + if (failed( + sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue()))) + return failure(); + + emitter.cacheDeferredOpResult(constantOp.getResult(), out); return success(); } @@ -838,7 +868,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { static LogicalResult printOperation(CppEmitter &emitter, emitc::ExpressionOp expressionOp) { - if (shouldBeInlined(expressionOp)) + if (emitter.shouldBeInlined(expressionOp)) return success(); Operation &op = *expressionOp.getOperation(); @@ -892,7 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { dyn_cast_if_present(value.getDefiningOp()); if (!expressionOp) return false; - return shouldBeInlined(expressionOp); + return emitter.shouldBeInlined(expressionOp); }; os << "for ("; @@ -1114,7 +1144,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter, functionOp->walk([&](Operation *op) -> WalkResult { if (isa(op->getParentOp()) || (isa(op) && - shouldBeInlined(cast(op)))) + emitter.shouldBeInlined(cast(op)))) return WalkResult::skip(); for (OpResult result : op->getResults()) { if (failed(emitter.emitVariableDeclaration( @@ -1494,22 +1524,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { LogicalResult CppEmitter::emitOperand(Value value) { Operation *def = value.getDefiningOp(); - if (!shouldUseConstantsAsVariables()) { - if (auto constant = dyn_cast_if_present(def)) { - os << "(("; - - if (failed(emitType(constant.getLoc(), constant.getType()))) { - return failure(); - } - os << ") "; - - if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) { - return failure(); - } - os << ")"; - return success(); - } - } if (isPartOfCurrentExpression(value)) { assert(def && "Expected operand to be defined by an operation"); @@ -1721,11 +1735,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { cacheDeferredOpResult(op.getResult(), op.getValue()); return success(); }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { + .Case([&](auto op) { cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); return success(); }) diff --git a/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir index 5774bdc47308f..c908ecc460edf 100644 --- a/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir +++ b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir @@ -11,9 +11,55 @@ func.func @test() { return } +// CPP-DEFAULT-LABEL: void test() { +// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) { +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// ----- + +func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue) { + %c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue + return %0 : !emitc.lvalue +} +// CPP-DEFAULT-LABEL: float test_subscript(float v1[4]) { +// CPP-DEFAULT-NEXT: return v1[(size_t) 0]; +// CPP-DEFAULT-NEXT: } + +// ----- -// CPP-DEFAULT: void test() { -// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) { +func.func @emitc_switch_ui64() { + %0 = "emitc.constant"(){value = 1 : ui64} : () -> ui64 + + emitc.switch %0 : ui64 + default { + emitc.call_opaque "func2" (%0) : (ui64) -> () + emitc.yield + } + return +} +// CPP-DEFAULT-LABEL: void emitc_switch_ui64() { +// CPP-DEFAULT: switch ((uint64_t) 1) { +// CPP-DEFAULT-NEXT: default: { +// CPP-DEFAULT-NEXT: func2((uint64_t) 1); +// CPP-DEFAULT-NEXT: break; // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; // CPP-DEFAULT-NEXT: } + +// ----- + +func.func @negative_values() { + %1 = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t + %2 = "emitc.constant"() <{value = -3000000000 : index}> : () -> !emitc.ssize_t + + %3 = emitc.add %1, %2 : (!emitc.size_t, !emitc.ssize_t) -> !emitc.ssize_t + + return +} +// CPP-DEFAULT-LABEL: void negative_values() { +// CPP-DEFAULT-NEXT: ssize_t v1 = (size_t) 10 + (ssize_t) -3000000000; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } \ No newline at end of file