@@ -3958,34 +3958,30 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
39583958 global_scope = current_scope;
39593959
39603960 ASR::Module_t* module_sym = nullptr ;
3961- if (!main_module) {
3962- // Main module goes directly to TranslationUnit.
3963- // Every other module goes into a Module.
3964- SymbolTable *parent_scope = current_scope;
3965- current_scope = al.make_new <SymbolTable>(parent_scope);
3961+ // Every module goes into a Module_t
3962+ SymbolTable *parent_scope = current_scope;
3963+ current_scope = al.make_new <SymbolTable>(parent_scope);
39663964
3967- std::string mod_name = module_name;
3968- ASR::asr_t *tmp1 = ASR::make_Module_t (al, x.base .base .loc ,
3969- /* a_symtab */ current_scope,
3970- /* a_name */ s2c (al, mod_name),
3971- nullptr ,
3972- 0 ,
3973- false , false );
3965+ ASR::asr_t *tmp1 = ASR::make_Module_t (al, x.base .base .loc ,
3966+ /* a_symtab */ current_scope,
3967+ /* a_name */ s2c (al, module_name),
3968+ nullptr ,
3969+ 0 ,
3970+ false , false );
39743971
3975- if (parent_scope->get_scope ().find (mod_name) != parent_scope->get_scope ().end ()) {
3976- throw SemanticError (" Module '" + mod_name + " ' already defined" , tmp1->loc );
3977- }
3978- module_sym = ASR::down_cast<ASR::Module_t>(ASR::down_cast<ASR::symbol_t >(tmp1));
3979- parent_scope->add_symbol (mod_name, ASR::down_cast<ASR::symbol_t >(tmp1));
3972+ if (parent_scope->get_scope ().find (module_name) != parent_scope->get_scope ().end ()) {
3973+ throw SemanticError (" Module '" + module_name + " ' already defined" , tmp1->loc );
39803974 }
3975+ module_sym = ASR::down_cast<ASR::Module_t>(ASR::down_cast<ASR::symbol_t >(tmp1));
3976+ parent_scope->add_symbol (module_name, ASR::down_cast<ASR::symbol_t >(tmp1));
39813977 current_module_dependencies.reserve (al, 1 );
39823978 for (size_t i=0 ; i<x.n_body ; i++) {
39833979 visit_stmt (*x.m_body [i]);
39843980 }
3985- if ( module_sym ) {
3986- module_sym-> m_dependencies = current_module_dependencies. p ;
3987- module_sym->n_dependencies = current_module_dependencies.size () ;
3988- }
3981+
3982+ LCOMPILERS_ASSERT ( module_sym != nullptr ) ;
3983+ module_sym->m_dependencies = current_module_dependencies.p ;
3984+ module_sym-> n_dependencies = current_module_dependencies. size ();
39893985 if (!overload_defs.empty ()) {
39903986 create_GenericProcedure (x.base .base .loc );
39913987 }
@@ -4745,12 +4741,12 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
47454741 ASR::symbol_t * module_sym = nullptr ;
47464742 ASR::Module_t* mod = nullptr ;
47474743
4748- if (!main_module) {
4749- module_sym = current_scope->get_symbol (module_name);
4750- mod = ASR::down_cast<ASR::Module_t>(module_sym);
4751- current_scope = mod-> m_symtab ;
4752- LCOMPILERS_ASSERT ( current_scope != nullptr ) ;
4753- }
4744+ LCOMPILERS_ASSERT (module_name. size () > 0 );
4745+ module_sym = current_scope->get_symbol (module_name);
4746+ mod = ASR::down_cast<ASR::Module_t>(module_sym);
4747+ LCOMPILERS_ASSERT (mod != nullptr ) ;
4748+ current_scope = mod-> m_symtab ;
4749+ LCOMPILERS_ASSERT (current_scope != nullptr );
47544750
47554751 Vec<ASR::asr_t *> items;
47564752 items.reserve (al, 4 );
@@ -4770,68 +4766,59 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
47704766 }
47714767 }
47724768
4773- if ( mod ) {
4774- for ( size_t i = 0 ; i < mod->n_dependencies ; i++ ) {
4775- current_module_dependencies.push_back (al, mod->m_dependencies [i]);
4776- }
4777- mod->m_dependencies = current_module_dependencies.p ;
4778- mod->n_dependencies = current_module_dependencies.n ;
4779-
4780- if (global_init.n > 0 ) {
4781- // unit->m_items is used and set to nullptr in the
4782- // `pass_wrap_global_stmts_into_function` pass
4783- unit->m_items = global_init.p ;
4784- unit->n_items = global_init.size ();
4785- std::string func_name = " global_initializer" ;
4786- LCompilers::PassOptions pass_options;
4787- pass_options.run_fun = func_name;
4788- pass_wrap_global_stmts (al, *unit, pass_options);
4789-
4790- ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
4791- if (f_sym) {
4792- // Add the `global_initilaizer` function into the
4793- // module and later call this function to initialize the
4794- // global variables like list, ...
4795- ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4796- f->m_symtab ->parent = mod->m_symtab ;
4797- mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
4798- // Erase the function in TranslationUnit
4799- unit->m_global_scope ->erase_symbol (func_name);
4800- }
4801- global_init.p = nullptr ;
4802- global_init.n = 0 ;
4803- }
4804-
4805- if (items.n > 0 ) {
4806- unit->m_items = items.p ;
4807- unit->n_items = items.size ();
4808- std::string func_name = " global_statements" ;
4809- // Wrap all the global statements into a Function
4810- LCompilers::PassOptions pass_options;
4811- pass_options.run_fun = func_name;
4812- pass_wrap_global_stmts (al, *unit, pass_options);
4813-
4814- ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
4815- if (f_sym) {
4816- // Add the `global_statements` function into the
4817- // module and later call this function to execute the
4818- // global_statements
4819- ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4820- f->m_symtab ->parent = mod->m_symtab ;
4821- mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
4822- // Erase the function in TranslationUnit
4823- unit->m_global_scope ->erase_symbol (func_name);
4824- }
4825- items.p = nullptr ;
4826- items.n = 0 ;
4827- }
4828- } else {
4829- // It is main_module
4830- for (auto item:items) {
4831- global_init.push_back (al, item);
4832- }
4769+ for ( size_t i = 0 ; i < mod->n_dependencies ; i++ ) {
4770+ current_module_dependencies.push_back (al, mod->m_dependencies [i]);
4771+ }
4772+ mod->m_dependencies = current_module_dependencies.p ;
4773+ mod->n_dependencies = current_module_dependencies.n ;
4774+
4775+ if (global_init.n > 0 ) {
4776+ // unit->m_items is used and set to nullptr in the
4777+ // `pass_wrap_global_stmts_into_function` pass
48334778 unit->m_items = global_init.p ;
48344779 unit->n_items = global_init.size ();
4780+ std::string func_name = " global_initializer" ;
4781+ LCompilers::PassOptions pass_options;
4782+ pass_options.run_fun = func_name;
4783+ pass_wrap_global_stmts (al, *unit, pass_options);
4784+
4785+ ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
4786+ if (f_sym) {
4787+ // Add the `global_initilaizer` function into the
4788+ // module and later call this function to initialize the
4789+ // global variables like list, ...
4790+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4791+ f->m_symtab ->parent = mod->m_symtab ;
4792+ mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
4793+ // Erase the function in TranslationUnit
4794+ unit->m_global_scope ->erase_symbol (func_name);
4795+ }
4796+ global_init.p = nullptr ;
4797+ global_init.n = 0 ;
4798+ }
4799+
4800+ if (items.n > 0 ) {
4801+ unit->m_items = items.p ;
4802+ unit->n_items = items.size ();
4803+ std::string func_name = " global_statements" ;
4804+ // Wrap all the global statements into a Function
4805+ LCompilers::PassOptions pass_options;
4806+ pass_options.run_fun = func_name;
4807+ pass_wrap_global_stmts (al, *unit, pass_options);
4808+
4809+ ASR::symbol_t *f_sym = unit->m_global_scope ->get_symbol (func_name);
4810+ if (f_sym) {
4811+ // Add the `global_statements` function into the
4812+ // module and later call this function to execute the
4813+ // global_statements
4814+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(f_sym);
4815+ f->m_symtab ->parent = mod->m_symtab ;
4816+ mod->m_symtab ->add_symbol (func_name, (ASR::symbol_t *) f);
4817+ // Erase the function in TranslationUnit
4818+ unit->m_global_scope ->erase_symbol (func_name);
4819+ }
4820+ items.p = nullptr ;
4821+ items.n = 0 ;
48354822 }
48364823
48374824 tmp = asr;
@@ -7701,35 +7688,49 @@ Result<ASR::TranslationUnit_t*> python_ast_to_asr(Allocator &al, LocationManager
77017688#endif
77027689 }
77037690
7704- if (main_module) {
7691+ if (main_module && !compiler_options. disable_main ) {
77057692 // If it is a main module, turn it into a program
77067693 // Note: we can modify this behavior for interactive mode later
7707- LCompilers::PassOptions pass_options;
7708- pass_options.disable_main = compiler_options.disable_main ;
7709- if (compiler_options.disable_main ) {
7710- if (tu->n_items > 0 ) {
7711- diagnostics.add (diag::Diagnostic (
7712- " The script is invoked as the main module and it has code to execute,\n "
7713- " but `--disable-main` was passed so no code was generated for `main`.\n "
7714- " We are removing all global executable code from ASR." ,
7715- diag::Level::Warning, diag::Stage::Semantic, {})
7716- );
7717- // We have to remove the code
7718- tu->m_items =nullptr ;
7719- tu->n_items =0 ;
7720- // LCOMPILERS_ASSERT(asr_verify(*tu));
7721- }
7722- } else {
7723- pass_options.run_fun = " _lpython_main_program" ;
7724- pass_options.runtime_library_dir = get_runtime_library_dir ();
7694+
7695+ Vec<ASR::stmt_t *> prog_body;
7696+ prog_body.reserve (al, 1 );
7697+ SetChar prog_dep;
7698+ prog_dep.reserve (al, 1 );
7699+ SymbolTable *program_scope = al.make_new <SymbolTable>(tu->m_global_scope );
7700+
7701+ std::string mod_name = " __main__" ;
7702+ ASR::symbol_t *mod_sym = tu->m_global_scope ->resolve_symbol (mod_name);
7703+ LCOMPILERS_ASSERT (mod_sym);
7704+ ASR::Module_t *mod = ASR::down_cast<ASR::Module_t>(mod_sym);
7705+ LCOMPILERS_ASSERT (mod);
7706+ std::vector<ASR::asr_t *> tmp_vec;
7707+ get_calls_to_global_init_and_stmts (al, tu->base .base .loc , program_scope, mod, tmp_vec);
7708+
7709+ for (auto i:tmp_vec) {
7710+ prog_body.push_back (al, ASRUtils::STMT (i));
7711+ }
7712+
7713+ if (prog_body.size () > 0 ) {
7714+ prog_dep.push_back (al, s2c (al, mod_name));
77257715 }
7726- pass_wrap_global_stmts_program (al, *tu, pass_options);
7716+
7717+ std::string prog_name = " main_program" ;
7718+ ASR::asr_t *prog = ASR::make_Program_t (
7719+ al, tu->base .base .loc ,
7720+ /* a_symtab */ program_scope,
7721+ /* a_name */ s2c (al, prog_name),
7722+ prog_dep.p ,
7723+ prog_dep.n ,
7724+ /* a_body */ prog_body.p ,
7725+ /* n_body */ prog_body.n );
7726+ tu->m_global_scope ->add_symbol (prog_name, ASR::down_cast<ASR::symbol_t >(prog));
7727+
77277728 #if defined(WITH_LFORTRAN_ASSERT)
7728- diag::Diagnostics diagnostics;
7729- if (!asr_verify (*tu, true , diagnostics)) {
7730- std::cerr << diagnostics.render2 ();
7731- throw LCompilersException (" Verify failed" );
7732- };
7729+ diag::Diagnostics diagnostics;
7730+ if (!asr_verify (*tu, true , diagnostics)) {
7731+ std::cerr << diagnostics.render2 ();
7732+ throw LCompilersException (" Verify failed" );
7733+ };
77337734 #endif
77347735 }
77357736
0 commit comments