Skip to content

Commit de361b5

Browse files
committed
Process main_module as every other module
Work towards Program_t only if disable_main is false
1 parent fd38017 commit de361b5

File tree

1 file changed

+113
-112
lines changed

1 file changed

+113
-112
lines changed

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 113 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -3958,34 +3958,30 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
39583958
global_scope = current_scope;
39593959

39603960
ASR::Module_t* module_sym = nullptr;
3961-
if (!main_module) {
3962-
// Main module goes directly to TranslationUnit.
3963-
// Every other module goes into a Module.
3964-
SymbolTable *parent_scope = current_scope;
3965-
current_scope = al.make_new<SymbolTable>(parent_scope);
3961+
// Every module goes into a Module_t
3962+
SymbolTable *parent_scope = current_scope;
3963+
current_scope = al.make_new<SymbolTable>(parent_scope);
39663964

3967-
std::string mod_name = module_name;
3968-
ASR::asr_t *tmp1 = ASR::make_Module_t(al, x.base.base.loc,
3969-
/* a_symtab */ current_scope,
3970-
/* a_name */ s2c(al, mod_name),
3971-
nullptr,
3972-
0,
3973-
false, false);
3965+
ASR::asr_t *tmp1 = ASR::make_Module_t(al, x.base.base.loc,
3966+
/* a_symtab */ current_scope,
3967+
/* a_name */ s2c(al, module_name),
3968+
nullptr,
3969+
0,
3970+
false, false);
39743971

3975-
if (parent_scope->get_scope().find(mod_name) != parent_scope->get_scope().end()) {
3976-
throw SemanticError("Module '" + mod_name + "' already defined", tmp1->loc);
3977-
}
3978-
module_sym = ASR::down_cast<ASR::Module_t>(ASR::down_cast<ASR::symbol_t>(tmp1));
3979-
parent_scope->add_symbol(mod_name, ASR::down_cast<ASR::symbol_t>(tmp1));
3972+
if (parent_scope->get_scope().find(module_name) != parent_scope->get_scope().end()) {
3973+
throw SemanticError("Module '" + module_name + "' already defined", tmp1->loc);
39803974
}
3975+
module_sym = ASR::down_cast<ASR::Module_t>(ASR::down_cast<ASR::symbol_t>(tmp1));
3976+
parent_scope->add_symbol(module_name, ASR::down_cast<ASR::symbol_t>(tmp1));
39813977
current_module_dependencies.reserve(al, 1);
39823978
for (size_t i=0; i<x.n_body; i++) {
39833979
visit_stmt(*x.m_body[i]);
39843980
}
3985-
if( module_sym ) {
3986-
module_sym->m_dependencies = current_module_dependencies.p;
3987-
module_sym->n_dependencies = current_module_dependencies.size();
3988-
}
3981+
3982+
LCOMPILERS_ASSERT(module_sym != nullptr);
3983+
module_sym->m_dependencies = current_module_dependencies.p;
3984+
module_sym->n_dependencies = current_module_dependencies.size();
39893985
if (!overload_defs.empty()) {
39903986
create_GenericProcedure(x.base.base.loc);
39913987
}
@@ -4745,12 +4741,12 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
47454741
ASR::symbol_t* module_sym = nullptr;
47464742
ASR::Module_t* mod = nullptr;
47474743

4748-
if (!main_module) {
4749-
module_sym = current_scope->get_symbol(module_name);
4750-
mod = ASR::down_cast<ASR::Module_t>(module_sym);
4751-
current_scope = mod->m_symtab;
4752-
LCOMPILERS_ASSERT(current_scope != nullptr);
4753-
}
4744+
LCOMPILERS_ASSERT(module_name.size() > 0);
4745+
module_sym = current_scope->get_symbol(module_name);
4746+
mod = ASR::down_cast<ASR::Module_t>(module_sym);
4747+
LCOMPILERS_ASSERT(mod != nullptr);
4748+
current_scope = mod->m_symtab;
4749+
LCOMPILERS_ASSERT(current_scope != nullptr);
47544750

47554751
Vec<ASR::asr_t*> items;
47564752
items.reserve(al, 4);
@@ -4770,68 +4766,59 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
47704766
}
47714767
}
47724768

4773-
if( mod ) {
4774-
for( size_t i = 0; i < mod->n_dependencies; i++ ) {
4775-
current_module_dependencies.push_back(al, mod->m_dependencies[i]);
4776-
}
4777-
mod->m_dependencies = current_module_dependencies.p;
4778-
mod->n_dependencies = current_module_dependencies.n;
4779-
4780-
if (global_init.n > 0) {
4781-
// unit->m_items is used and set to nullptr in the
4782-
// `pass_wrap_global_stmts_into_function` pass
4783-
unit->m_items = global_init.p;
4784-
unit->n_items = global_init.size();
4785-
std::string func_name = "global_initializer";
4786-
LCompilers::PassOptions pass_options;
4787-
pass_options.run_fun = func_name;
4788-
pass_wrap_global_stmts(al, *unit, pass_options);
4789-
4790-
ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name);
4791-
if (f_sym) {
4792-
// Add the `global_initilaizer` function into the
4793-
// module and later call this function to initialize the
4794-
// global variables like list, ...
4795-
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4796-
f->m_symtab->parent = mod->m_symtab;
4797-
mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f);
4798-
// Erase the function in TranslationUnit
4799-
unit->m_global_scope->erase_symbol(func_name);
4800-
}
4801-
global_init.p = nullptr;
4802-
global_init.n = 0;
4803-
}
4804-
4805-
if (items.n > 0) {
4806-
unit->m_items = items.p;
4807-
unit->n_items = items.size();
4808-
std::string func_name = "global_statements";
4809-
// Wrap all the global statements into a Function
4810-
LCompilers::PassOptions pass_options;
4811-
pass_options.run_fun = func_name;
4812-
pass_wrap_global_stmts(al, *unit, pass_options);
4813-
4814-
ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name);
4815-
if (f_sym) {
4816-
// Add the `global_statements` function into the
4817-
// module and later call this function to execute the
4818-
// global_statements
4819-
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4820-
f->m_symtab->parent = mod->m_symtab;
4821-
mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f);
4822-
// Erase the function in TranslationUnit
4823-
unit->m_global_scope->erase_symbol(func_name);
4824-
}
4825-
items.p = nullptr;
4826-
items.n = 0;
4827-
}
4828-
} else {
4829-
// It is main_module
4830-
for (auto item:items) {
4831-
global_init.push_back(al, item);
4832-
}
4769+
for( size_t i = 0; i < mod->n_dependencies; i++ ) {
4770+
current_module_dependencies.push_back(al, mod->m_dependencies[i]);
4771+
}
4772+
mod->m_dependencies = current_module_dependencies.p;
4773+
mod->n_dependencies = current_module_dependencies.n;
4774+
4775+
if (global_init.n > 0) {
4776+
// unit->m_items is used and set to nullptr in the
4777+
// `pass_wrap_global_stmts_into_function` pass
48334778
unit->m_items = global_init.p;
48344779
unit->n_items = global_init.size();
4780+
std::string func_name = "global_initializer";
4781+
LCompilers::PassOptions pass_options;
4782+
pass_options.run_fun = func_name;
4783+
pass_wrap_global_stmts(al, *unit, pass_options);
4784+
4785+
ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name);
4786+
if (f_sym) {
4787+
// Add the `global_initilaizer` function into the
4788+
// module and later call this function to initialize the
4789+
// global variables like list, ...
4790+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4791+
f->m_symtab->parent = mod->m_symtab;
4792+
mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f);
4793+
// Erase the function in TranslationUnit
4794+
unit->m_global_scope->erase_symbol(func_name);
4795+
}
4796+
global_init.p = nullptr;
4797+
global_init.n = 0;
4798+
}
4799+
4800+
if (items.n > 0) {
4801+
unit->m_items = items.p;
4802+
unit->n_items = items.size();
4803+
std::string func_name = "global_statements";
4804+
// Wrap all the global statements into a Function
4805+
LCompilers::PassOptions pass_options;
4806+
pass_options.run_fun = func_name;
4807+
pass_wrap_global_stmts(al, *unit, pass_options);
4808+
4809+
ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name);
4810+
if (f_sym) {
4811+
// Add the `global_statements` function into the
4812+
// module and later call this function to execute the
4813+
// global_statements
4814+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4815+
f->m_symtab->parent = mod->m_symtab;
4816+
mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f);
4817+
// Erase the function in TranslationUnit
4818+
unit->m_global_scope->erase_symbol(func_name);
4819+
}
4820+
items.p = nullptr;
4821+
items.n = 0;
48354822
}
48364823

48374824
tmp = asr;
@@ -7701,35 +7688,49 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al, LocationManager
77017688
#endif
77027689
}
77037690

7704-
if (main_module) {
7691+
if (main_module && !compiler_options.disable_main) {
77057692
// If it is a main module, turn it into a program
77067693
// Note: we can modify this behavior for interactive mode later
7707-
LCompilers::PassOptions pass_options;
7708-
pass_options.disable_main = compiler_options.disable_main;
7709-
if (compiler_options.disable_main) {
7710-
if (tu->n_items > 0) {
7711-
diagnostics.add(diag::Diagnostic(
7712-
"The script is invoked as the main module and it has code to execute,\n"
7713-
"but `--disable-main` was passed so no code was generated for `main`.\n"
7714-
"We are removing all global executable code from ASR.",
7715-
diag::Level::Warning, diag::Stage::Semantic, {})
7716-
);
7717-
// We have to remove the code
7718-
tu->m_items=nullptr;
7719-
tu->n_items=0;
7720-
// LCOMPILERS_ASSERT(asr_verify(*tu));
7721-
}
7722-
} else {
7723-
pass_options.run_fun = "_lpython_main_program";
7724-
pass_options.runtime_library_dir = get_runtime_library_dir();
7694+
7695+
Vec<ASR::stmt_t*> prog_body;
7696+
prog_body.reserve(al, 1);
7697+
SetChar prog_dep;
7698+
prog_dep.reserve(al, 1);
7699+
SymbolTable *program_scope = al.make_new<SymbolTable>(tu->m_global_scope);
7700+
7701+
std::string mod_name = "__main__";
7702+
ASR::symbol_t *mod_sym = tu->m_global_scope->resolve_symbol(mod_name);
7703+
LCOMPILERS_ASSERT(mod_sym);
7704+
ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(mod_sym);
7705+
LCOMPILERS_ASSERT(mod);
7706+
std::vector<ASR::asr_t*> tmp_vec;
7707+
get_calls_to_global_init_and_stmts(al, tu->base.base.loc, program_scope, mod, tmp_vec);
7708+
7709+
for (auto i:tmp_vec) {
7710+
prog_body.push_back(al, ASRUtils::STMT(i));
7711+
}
7712+
7713+
if (prog_body.size() > 0) {
7714+
prog_dep.push_back(al, s2c(al, mod_name));
77257715
}
7726-
pass_wrap_global_stmts_program(al, *tu, pass_options);
7716+
7717+
std::string prog_name = "main_program";
7718+
ASR::asr_t *prog = ASR::make_Program_t(
7719+
al, tu->base.base.loc,
7720+
/* a_symtab */ program_scope,
7721+
/* a_name */ s2c(al, prog_name),
7722+
prog_dep.p,
7723+
prog_dep.n,
7724+
/* a_body */ prog_body.p,
7725+
/* n_body */ prog_body.n);
7726+
tu->m_global_scope->add_symbol(prog_name, ASR::down_cast<ASR::symbol_t>(prog));
7727+
77277728
#if defined(WITH_LFORTRAN_ASSERT)
7728-
diag::Diagnostics diagnostics;
7729-
if (!asr_verify(*tu, true, diagnostics)) {
7730-
std::cerr << diagnostics.render2();
7731-
throw LCompilersException("Verify failed");
7732-
};
7729+
diag::Diagnostics diagnostics;
7730+
if (!asr_verify(*tu, true, diagnostics)) {
7731+
std::cerr << diagnostics.render2();
7732+
throw LCompilersException("Verify failed");
7733+
};
77337734
#endif
77347735
}
77357736

0 commit comments

Comments
 (0)