66from logging import Logger , basicConfig , getLogger
77from os import getenv , environ
88from pathlib import Path
9- from typing import List
9+ from typing import List , Set , Optional
1010
1111
1212logger = getLogger (__name__ ) # type: Logger
1313
14- atcoder_include = re .compile ('#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
1514
16- include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
17-
18- lib_path = Path .cwd ()
19-
20- defined = set ()
21-
22- def dfs (f : str ) -> List [str ]:
23- global defined
24- if f in defined :
25- logger .info ('already included {}, skip' .format (f ))
26- return []
27- defined .add (f )
28-
29- logger .info ('include {}' .format (f ))
30-
31- s = open (str (lib_path / f )).read ()
32- result = []
33- for line in s .splitlines ():
34- if include_guard .match (line ):
35- continue
36-
37- m = atcoder_include .match (line )
38- if m :
39- result .extend (dfs (m .group (1 )))
40- continue
41- result .append (line )
42- return result
15+ class Expander :
16+ atcoder_include = re .compile (
17+ '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
18+
19+ include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
20+
21+ def __init__ (self , lib_paths : List [Path ] = None ):
22+ if lib_paths :
23+ self .lib_paths = lib_paths
24+ else :
25+ self .lib_paths = [Path .cwd ()]
26+
27+ included = set () # type: Set[str]
28+
29+ def find_acl (self , acl_name : str ) -> Optional [Path ]:
30+ for lib_path in self .lib_paths :
31+ path = lib_path / acl_name
32+ if path .exists ():
33+ return path
34+ return None
35+
36+ def expand_acl (self , acl_name : str ) -> List [str ]:
37+ if acl_name in self .included :
38+ logger .info ('already included: {}' .format (acl_name ))
39+ return []
40+ self .included .add (acl_name )
41+ logger .info ('include: {}' .format (acl_name ))
42+ acl_path = self .find_acl (acl_name )
43+ if not acl_path :
44+ logger .warning ('cannot find: {}' .format (acl_name ))
45+ raise FileNotFoundError ()
46+
47+ acl_source = open (str (acl_path )).read ()
48+
49+ result = [] # type: List[str]
50+ for line in acl_source .splitlines ():
51+ if self .include_guard .match (line ):
52+ continue
53+
54+ m = self .atcoder_include .match (line )
55+ if m :
56+ result .extend (self .expand_acl (m .group (1 )))
57+ continue
58+ result .append (line )
59+ return result
60+
61+ def expand (self , source : str ) -> str :
62+ self .included = set ()
63+ result = [] # type: List[str]
64+ for line in source .splitlines ():
65+ m = self .atcoder_include .match (line )
66+
67+ if m :
68+ result .extend (self .expand_acl (m .group (1 )))
69+ continue
70+ result .append (line )
71+ return '\n ' .join (result )
4372
4473
4574if __name__ == "__main__" :
@@ -55,22 +84,16 @@ def dfs(f: str) -> List[str]:
5584 parser .add_argument ('--lib' , help = 'Path to Atcoder Library' )
5685 opts = parser .parse_args ()
5786
87+ lib_path = Path .cwd ()
5888 if opts .lib :
5989 lib_path = Path (opts .lib )
6090 elif 'CPLUS_INCLUDE_PATH' in environ :
6191 lib_path = Path (environ ['CPLUS_INCLUDE_PATH' ])
62- s = open (opts .source ).read ()
63-
64- result = []
65- for line in s .splitlines ():
66- m = atcoder_include .match (line )
6792
68- if m :
69- result .extend (dfs (m .group (1 )))
70- continue
71- result .append (line )
93+ expander = Expander ([lib_path ])
94+ source = open (opts .source ).read ()
95+ output = expander .expand (source )
7296
73- output = '\n ' .join (result ) + '\n '
7497 if opts .console :
7598 print (output )
7699 else :
0 commit comments