11from algorithms .configuration .configuration import Configuration
22from algorithms .algorithm_manager import AlgorithmManager
3+ from maps .map_manager import MapManager
34from algorithms .lstm .trainer import Trainer
45from analyzer .analyzer import Analyzer
56from generator .generator import Generator
@@ -151,7 +152,7 @@ def configure_common(config, args) -> bool:
151152 print ("Available algorithms:" )
152153 for key in AlgorithmManager .builtins .keys ():
153154 print (f" { key } " )
154- print ("Or specify your own file with a class that inherits from Algorithm" )
155+ print ("Or specify your own file that contains a class that inherits from Algorithm" )
155156 sys .exit (0 )
156157
157158 if args .algorithms :
@@ -162,7 +163,7 @@ def configure_common(config, args) -> bool:
162163 valid_str = "," .join ('"' + a + '"' for a in AlgorithmManager .builtins .keys ())
163164 print (f"Invalid algorithm(s) specified: { invalid_str } " , file = sys .stderr )
164165 print (f"Available algorithms: { valid_str } " , file = sys .stderr )
165- print ("Or specify your own file with a class that inherits from Algorithm" , file = sys .stderr )
166+ print ("Or specify your own file that contains a class that inherits from Algorithm" , file = sys .stderr )
166167 return False
167168
168169 algorithms = list (flatten (algorithms , depth = 1 ))
@@ -179,6 +180,46 @@ def configure_common(config, args) -> bool:
179180
180181 config .algorithms = algorithms
181182
183+ if args .list_maps :
184+ print ("Available maps:" )
185+ for key in MapManager .builtins .keys ():
186+ print (f" { key } " )
187+ print ("Can also specify a custom map," )
188+ print (" (1) cached map stored in Maps" )
189+ print (" (2) external file that contains a global variable with type that inherits from Map" )
190+ sys .exit (0 )
191+
192+ if args .maps :
193+ maps = MapManager .load_all (args .maps )
194+ if not all (maps ):
195+ invalid_maps = [args .maps [i ] for i in range (len (maps )) if not maps [i ]]
196+ invalid_str = "," .join ('"' + a + '"' for a in invalid_maps )
197+ valid_str = "," .join ('"' + a + '"' for a in MapManager .builtins .keys ())
198+ print (f"Invalid map(s) specified: { invalid_str } " , file = sys .stderr )
199+ print (f"Available maps: { valid_str } " , file = sys .stderr )
200+ print ("Can also specify a custom map," , file = sys .stderr )
201+ print (" (1) cached map stored in Maps" , file = sys .stderr )
202+ print (" (2) external file that contains a global variable with type that inherits from Map" , file = sys .stderr )
203+ return False
204+
205+ maps = list (flatten (maps , depth = 1 ))
206+
207+ # name uniqueness
208+ names = [a [0 ] for a in maps ]
209+ if len (set (names )) != len (names ):
210+ print ("Name conflict detected in custom map list:" , names , file = sys .stderr )
211+ return False
212+
213+ maps = dict (maps )
214+ if args .include_default_builtin_maps or args .include_all_builtin_maps :
215+ maps .update (MapManager .builtins )
216+ if args .include_all_builtin_maps :
217+ maps .update (MapManager .cached_builtins )
218+
219+ config .maps = maps
220+ elif args .include_all_builtin_maps :
221+ config .maps .update (MapManager .cached_builtins )
222+
182223 if args .deterministic :
183224 random .seed (args .std_random_seed )
184225 torch .manual_seed (args .torch_random_seed )
@@ -225,9 +266,17 @@ def main() -> bool:
225266 parser .add_argument ("--dims" , type = int , help = "[generator|analyzer] number of dimensions" , default = 3 )
226267
227268 parser .add_argument ("--algorithms" , help = "[visualiser|analyzer] algorithms to load (either built-in algorithm name or module file path)" , nargs = "+" )
228- parser .add_argument ("--include-builtin-algorithms" , action = 'store_true' , help = "include all builtin algorithms even when a custom list is provided via '--algorithms'" )
269+ parser .add_argument ("--include-builtin-algorithms" , action = 'store_true' ,
270+ help = "[visualiser|analyzer] include all builtin algorithms even when a custom list is provided via '--algorithms'" )
229271 parser .add_argument ("--list-algorithms" , action = "store_true" , help = "[visualiser|analyzer] output list of available built-in algorithms" )
230272
273+ parser .add_argument ("--maps" , help = "[visualiser|analyzer|trainer] maps to load (either built-in map name or module file path)" , nargs = "+" )
274+ parser .add_argument ("--include-all-builtin-maps" , action = 'store_true' ,
275+ help = "[visualiser|analyzer|trainer] include all builtin maps (includes all cached maps) even when a custom list is provided via '--maps'" )
276+ parser .add_argument ("--include-default-builtin-maps" , action = 'store_true' ,
277+ help = "[visualiser|analyzer|trainer] include default builtin maps (does not include all cached maps) even when a custom list is provided via '--maps'" )
278+ parser .add_argument ("--list-maps" , action = "store_true" , help = "[visualiser|analyzer|trainer] output list of available built-in maps" )
279+
231280 parser .add_argument ("--deterministic" , action = 'store_true' , help = "use pre-defined random seeds for deterministic exeuction" )
232281 parser .add_argument ("--std-random-seed" , type = int , default = 0 , help = "'random' module random number generator seed" )
233282 parser .add_argument ("--numpy-random-seed" , type = int , default = 0 , help = "'numpy' module random number generator seed" )
0 commit comments