Skip to content

Commit 1a17905

Browse files
committed
Add forward decl for functions
1 parent ea9715b commit 1a17905

File tree

1 file changed

+41
-19
lines changed

1 file changed

+41
-19
lines changed

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
137137
std::map<int32_t, std::string> gotoid2name;
138138
std::map<std::string, std::string> emit_headers;
139139
std::string array_types_decls;
140+
std::string forward_decl_functions;
140141

141142
// Output configuration:
142143
// Use std::string or char*
@@ -493,18 +494,18 @@ R"(#include <stdio.h>
493494
}
494495

495496
// Returns the declaration, no semi colon at the end
496-
std::string get_function_declaration(const ASR::Function_t &x, bool &has_typevar) {
497+
std::string get_function_declaration(const ASR::Function_t &x, bool &has_typevar, bool is_pointer=false) {
497498
template_for_Kokkos.clear();
498499
template_number = 0;
499500
std::string sub, inl, static_attr;
500501

501502
// This helps to check if the function is generic.
502503
// If it is generic we skip the codegen for that function.
503504
has_typevar = false;
504-
if (ASRUtils::get_FunctionType(x)->m_inline) {
505+
if (ASRUtils::get_FunctionType(x)->m_inline && !is_pointer) {
505506
inl = "inline __attribute__((always_inline)) ";
506507
}
507-
if( ASRUtils::get_FunctionType(x)->m_static ) {
508+
if( ASRUtils::get_FunctionType(x)->m_static && !is_pointer) {
508509
static_attr = "static ";
509510
}
510511
if (x.m_return_var) {
@@ -526,30 +527,47 @@ R"(#include <stdio.h>
526527
f_type->m_deftype == ASR::deftypeType::Implementation) {
527528
sym_name = "_xx_internal_" + sym_name + "_xx";
528529
}
529-
std::string func = static_attr + inl + sub + sym_name + "(";
530+
std::string func = static_attr + inl + sub;
531+
if (is_pointer) {
532+
func += "(*" + sym_name + ")(";
533+
} else {
534+
func += sym_name + "(";
535+
}
530536
bracket_open++;
531537
for (size_t i=0; i<x.n_args; i++) {
532-
ASR::Variable_t *arg = ASRUtils::EXPR2VAR(x.m_args[i]);
533-
LCOMPILERS_ASSERT(ASRUtils::is_arg_dummy(arg->m_intent));
534-
if (ASR::is_a<ASR::TypeParameter_t>(*arg->m_type)) {
535-
has_typevar = true;
536-
bracket_open--;
537-
return "";
538-
}
539-
if( is_c ) {
540-
CDeclarationOptions c_decl_options;
541-
c_decl_options.pre_initialise_derived_type = false;
542-
func += self().convert_variable_decl(*arg, &c_decl_options);
538+
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(
539+
ASR::down_cast<ASR::Var_t>(x.m_args[i])->m_v);
540+
if (ASR::is_a<ASR::Variable_t>(*sym)) {
541+
ASR::Variable_t *arg = ASR::down_cast<ASR::Variable_t>(sym);
542+
LCOMPILERS_ASSERT(ASRUtils::is_arg_dummy(arg->m_intent));
543+
if( is_c ) {
544+
CDeclarationOptions c_decl_options;
545+
c_decl_options.pre_initialise_derived_type = false;
546+
func += self().convert_variable_decl(*arg, &c_decl_options);
547+
} else {
548+
CPPDeclarationOptions cpp_decl_options;
549+
cpp_decl_options.use_static = false;
550+
cpp_decl_options.use_templates_for_arrays = true;
551+
func += self().convert_variable_decl(*arg, &cpp_decl_options);
552+
}
553+
if (ASR::is_a<ASR::TypeParameter_t>(*arg->m_type)) {
554+
has_typevar = true;
555+
bracket_open--;
556+
return "";
557+
}
558+
} else if (ASR::is_a<ASR::Function_t>(*sym)) {
559+
ASR::Function_t *fun = ASR::down_cast<ASR::Function_t>(sym);
560+
func += get_function_declaration(*fun, has_typevar, true);
543561
} else {
544-
CPPDeclarationOptions cpp_decl_options;
545-
cpp_decl_options.use_static = false;
546-
cpp_decl_options.use_templates_for_arrays = true;
547-
func += self().convert_variable_decl(*arg, &cpp_decl_options);
562+
throw CodeGenError("Unsupported function argument");
548563
}
549564
if (i < x.n_args-1) func += ", ";
550565
}
551566
func += ")";
552567
bracket_open--;
568+
if (f_type->m_abi == ASR::abiType::Source) {
569+
forward_decl_functions += func + ";\n";
570+
}
553571
if( is_c || template_for_Kokkos.empty() ) {
554572
return func;
555573
}
@@ -1716,6 +1734,10 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
17161734

17171735
void visit_Var(const ASR::Var_t &x) {
17181736
const ASR::symbol_t *s = ASRUtils::symbol_get_past_external(x.m_v);
1737+
if (ASR::is_a<ASR::Function_t>(*s)) {
1738+
src = ASRUtils::symbol_name(s);
1739+
return;
1740+
}
17191741
ASR::Variable_t* sv = ASR::down_cast<ASR::Variable_t>(s);
17201742
if( (sv->m_intent == ASRUtils::intent_in ||
17211743
sv->m_intent == ASRUtils::intent_inout) &&

0 commit comments

Comments
 (0)