@@ -31,7 +31,6 @@ class Path(BasePath):
3131 :param create: create the directory if it doesn't exist
3232 :param touch: create the file if it doesn't exist (mutually exclusive with 'create')
3333 """
34-
3534 _flavour = BasePath ()._flavour # fix to AttributeError
3635
3736 def __new__ (cls , * parts , ** kwargs ):
@@ -588,7 +587,7 @@ def todo(self):
588587
589588class PythonPath (Path ):
590589 """ Path extension for handling the dynamic import of Python modules. """
591- def __init__ (self , path , remove_cache = False ):
590+ def __init__ (self , path , remove_cache = False , ignore_main = True ):
592591 super (PythonPath , self ).__init__ ()
593592 if remove_cache :
594593 for p in self .walk (filter_func = lambda x : x .is_dir () and x .basename == "__pycache__" ):
@@ -608,14 +607,35 @@ def __init__(self, path, remove_cache=False):
608607 self .modules .append (p .module )
609608 else :
610609 self .loaded = False
610+ name = self .dirname .stem if self .stem == "__init__" else self .stem
611611 if self .extension in [".py" , ".pyc" ]:
612612 try :
613613 loader_cls = ["SourcelessFileLoader" , "SourceFileLoader" ][self .extension == ".py" ]
614- loader = getattr (importlib .machinery , loader_cls )(self . stem , str (self ))
615- spec = importlib .util .spec_from_file_location (self . stem , str (self ), loader = loader )
614+ loader = getattr (importlib .machinery , loader_cls )(name , str (self ))
615+ spec = importlib .util .spec_from_file_location (name , str (self ), loader = loader )
616616 self .module = importlib .util .module_from_spec (spec )
617617 sys .modules [self .module .__name__ ] = self .module
618- loader .exec_module (self .module )
618+ if self .extension == ".py" and ignore_main :
619+ from ast import parse , Compare , Constant , Eq , If , Name
620+ from types import ModuleType
621+ with self .open ('r' , encoding = "utf-8" ) as f :
622+ source = f .read ()
623+ tree = parse (source , filename = str (self ))
624+ filtered_body = []
625+ for node in tree .body :
626+ if isinstance (node , If ):
627+ test = node .test
628+ if isinstance (test , Compare ) and isinstance (test .left , Name ) and \
629+ test .left .id == "__name__" and any (isinstance (op , Eq ) for op in test .ops ) and \
630+ any (isinstance (comp , Constant ) and \
631+ comp .value == "__main__" for comp in test .comparators ):
632+ continue # skip this `if __name__ == '__main__':` block
633+ filtered_body .append (node )
634+ tree .body = filtered_body
635+ code = compile (tree , str (self ), "exec" )
636+ exec (code , self .module .__dict__ )
637+ else :
638+ loader .exec_module (self .module )
619639 self .loaded = True
620640 except (ImportError , NameError , SyntaxError , ValueError ):
621641 raise
0 commit comments