Skip to content

Commit 81a75da

Browse files
Added command line arguments to analyzer (#51)
* Adding cmd options for analyser Able to import custom modules inheriting from Algorithm Added docs in main Fixing rebase problems Added more generator options * Reworked main.py & polished code base Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Added AlgorithmManager Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Wired analyzer to use AlgorithmManager Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Polished main.py Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Bug fix Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Polished test code + arg parsing Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Cleanup and import error fix Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Import path cleanup Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Polished command line interface Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Attempt at fixing CI error Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Moved constants file to resolve import error Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Tweaked paths Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Implemented multiple algorithms per imported file Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Name conflict detection Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Removed stale files Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Polished error handling Signed-off-by: [ 大鳄 ] Asew <[email protected]> Co-authored-by: [ 大鳄 ] Asew <[email protected]>
1 parent 8ea33f4 commit 81a75da

26 files changed

+438
-10679
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
torch
22
numpy
3+
matplotlib
34
nptyping
45
panda3d
56
pandas
67
sklearn
7-
matplotlib
88
torchvision
99
memory_profiler
1010
seaborn
1111
dill
1212
natsort
1313
screeninfo
1414
opencv-python
15-
lru-dict
15+
lru-dict
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from algorithms.algorithm import Algorithm
2+
from utility.compatibility import HAS_OMPL
3+
4+
from typing import Optional, List, Type, Dict, Any, Tuple
5+
import importlib.util
6+
import inspect
7+
import os
8+
import sys
9+
import copy
10+
import traceback
11+
12+
# planner testing
13+
from algorithms.basic_testing import BasicTesting
14+
from algorithms.classic.testing.a_star_testing import AStarTesting
15+
from algorithms.classic.testing.combined_online_lstm_testing import CombinedOnlineLSTMTesting
16+
from algorithms.classic.testing.dijkstra_testing import DijkstraTesting
17+
from algorithms.classic.testing.wavefront_testing import WavefrontTesting
18+
from algorithms.classic.testing.way_point_navigation_testing import WayPointNavigationTesting
19+
20+
# planner implementations
21+
from algorithms.classic.graph_based.a_star import AStar
22+
from algorithms.classic.graph_based.bug1 import Bug1
23+
from algorithms.classic.graph_based.bug2 import Bug2
24+
from algorithms.classic.graph_based.dijkstra import Dijkstra
25+
from algorithms.classic.graph_based.potential_field import PotentialField
26+
from algorithms.classic.sample_based.sprm import SPRM
27+
from algorithms.classic.sample_based.rt import RT
28+
from algorithms.classic.sample_based.rrt import RRT
29+
from algorithms.classic.sample_based.rrt_star import RRT_Star
30+
from algorithms.classic.sample_based.rrt_connect import RRT_Connect
31+
from algorithms.classic.graph_based.wavefront import Wavefront
32+
from algorithms.lstm.LSTM_tile_by_tile import OnlineLSTM
33+
from algorithms.lstm.a_star_waypoint import WayPointNavigation
34+
from algorithms.lstm.combined_online_LSTM import CombinedOnlineLSTM
35+
36+
if HAS_OMPL:
37+
from algorithms.classic.sample_based.ompl_rrt import OMPL_RRT
38+
from algorithms.classic.sample_based.ompl_prmstar import OMPL_PRMstar
39+
from algorithms.classic.sample_based.ompl_lazyprmstar import OMPL_LazyPRMstar
40+
from algorithms.classic.sample_based.ompl_rrtstar import OMPL_RRTstar
41+
from algorithms.classic.sample_based.ompl_rrtsharp import OMPL_RRTsharp
42+
from algorithms.classic.sample_based.ompl_rrtx import OMPL_RRTXstatic
43+
from algorithms.classic.sample_based.ompl_informedrrt import OMPL_InformedRRT
44+
from algorithms.classic.sample_based.ompl_kpiece1 import OMPL_KPIECE1
45+
from algorithms.classic.sample_based.ompl_ltlplanner import OMPL_LTLPlanner
46+
from algorithms.classic.sample_based.ompl_pdst import OMPL_PDST
47+
from algorithms.classic.sample_based.ompl_sst import OMPL_SST
48+
from algorithms.classic.sample_based.ompl_aitstar import OMPL_AITstar
49+
from algorithms.classic.sample_based.ompl_anytimepathshortening import OMPL_AnytimePathShortening
50+
from algorithms.classic.sample_based.ompl_bfmt import OMPL_BFMT
51+
from algorithms.classic.sample_based.ompl_biest import OMPL_BiEST
52+
from algorithms.classic.sample_based.ompl_rrtconnect import OMPL_RRTConnect
53+
from algorithms.classic.sample_based.ompl_trrt import OMPL_TRRT
54+
from algorithms.classic.sample_based.ompl_birlrt import OMPL_BiRLRT
55+
from algorithms.classic.sample_based.ompl_bitrrt import OMPL_BiTRRT
56+
from algorithms.classic.sample_based.ompl_bitstar import OMPL_BITstar
57+
from algorithms.classic.sample_based.ompl_bkpiece1 import OMPL_BKPIECE1
58+
from algorithms.classic.sample_based.ompl_syclop import OMPL_Syclop
59+
from algorithms.classic.sample_based.ompl_cforest import OMPL_CForest
60+
from algorithms.classic.sample_based.ompl_est import OMPL_EST
61+
from algorithms.classic.sample_based.ompl_fmt import OMPL_FMT
62+
from algorithms.classic.sample_based.ompl_lazylbtrrt import OMPL_LazyLBTRRT
63+
from algorithms.classic.sample_based.ompl_lazyprm import OMPL_LazyPRM
64+
from algorithms.classic.sample_based.ompl_lazyrrt import OMPL_LazyRRT
65+
from algorithms.classic.sample_based.ompl_lbkpiece1 import OMPL_LBKPIECE1
66+
from algorithms.classic.sample_based.ompl_lbtrrt import OMPL_LBTRRT
67+
from algorithms.classic.sample_based.ompl_prm import OMPL_PRM
68+
from algorithms.classic.sample_based.ompl_spars import OMPL_SPARS
69+
from algorithms.classic.sample_based.ompl_spars2 import OMPL_SPARS2
70+
from algorithms.classic.sample_based.ompl_vfrrt import OMPL_VFRRT
71+
from algorithms.classic.sample_based.ompl_prrt import OMPL_pRRT
72+
from algorithms.classic.sample_based.ompl_tsrrt import OMPL_TSRRT
73+
from algorithms.classic.sample_based.ompl_psbl import OMPL_pSBL
74+
from algorithms.classic.sample_based.ompl_sbl import OMPL_SBL
75+
from algorithms.classic.sample_based.ompl_stride import OMPL_STRIDE
76+
from algorithms.classic.sample_based.ompl_qrrt import OMPL_QRRT
77+
78+
def static_class(cls):
79+
if getattr(cls, "_static_init_", None):
80+
cls._static_init_()
81+
return cls
82+
83+
@static_class
84+
class AlgorithmManager():
85+
MetaData = Tuple[Type[Algorithm], Type[BasicTesting], Tuple[List[Any], Dict[str, Any]]]
86+
87+
builtins: Dict[str, MetaData]
88+
89+
@classmethod
90+
def _static_init_(cls):
91+
cls.builtins = {
92+
"A*": (AStar, AStarTesting, ([], {})),
93+
"Global Way-point LSTM": (WayPointNavigation, WayPointNavigationTesting, (
94+
[], {"global_kernel": (CombinedOnlineLSTM, ([], {})), "global_kernel_max_it": 100})),
95+
"LSTM Bagging": (CombinedOnlineLSTM, CombinedOnlineLSTMTesting, ([], {})),
96+
"CAE Online LSTM": (
97+
OnlineLSTM, BasicTesting, ([], {"load_name": "caelstm_section_cae_training_house_100_model"})),
98+
"Online LSTM": (OnlineLSTM, BasicTesting, (
99+
[],
100+
{"load_name": "tile_by_tile_training_uniform_random_fill_3000_block_map_3000_house_3000_model"})),
101+
"SPRM": (SPRM, BasicTesting, ([], {})),
102+
"RT": (RT, BasicTesting, ([], {})),
103+
"RRT": (RRT, BasicTesting, ([], {})),
104+
"RRT*": (RRT_Star, BasicTesting, ([], {})),
105+
"RRT-Connect": (RRT_Connect, BasicTesting, ([], {})),
106+
"Wave-front": (Wavefront, WavefrontTesting, ([], {})),
107+
"Dijkstra": (Dijkstra, DijkstraTesting, ([], {})),
108+
"Bug1": (Bug1, BasicTesting, ([], {})),
109+
"Bug2": (Bug2, BasicTesting, ([], {})),
110+
"Potential Field": (PotentialField, BasicTesting, ([], {}))
111+
}
112+
113+
if HAS_OMPL:
114+
cls.builtins.update({
115+
"OMPL RRT": (OMPL_RRT, BasicTesting, ([], {})),
116+
"OMPL PRM*": (OMPL_PRMstar, BasicTesting, ([], {})),
117+
"OMPL Lazy PRM*": (OMPL_LazyPRMstar, BasicTesting, ([], {})),
118+
"OMPL RRT*": (OMPL_RRTstar, BasicTesting, ([], {})),
119+
"OMPL RRT#": (OMPL_RRTsharp, BasicTesting, ([], {})),
120+
"OMPL RRTX": (OMPL_RRTXstatic, BasicTesting, ([], {})),
121+
"OMPL KPIECE1": (OMPL_KPIECE1, BasicTesting, ([], {})),
122+
"OMPL LazyLBTRRT": (OMPL_LazyLBTRRT, BasicTesting, ([], {})),
123+
"OMPL LazyPRM": (OMPL_LazyPRM, BasicTesting, ([], {})),
124+
"OMPL LazyRRT": (OMPL_LazyRRT, BasicTesting, ([], {})),
125+
"OMPL LBKPIECE1": (OMPL_LBKPIECE1, BasicTesting, ([], {})),
126+
"OMPL LBTRRT": (OMPL_LBTRRT, BasicTesting, ([], {})),
127+
"OMPL PRM": (OMPL_PRM, BasicTesting, ([], {})),
128+
"OMPL SBL": (OMPL_SBL, BasicTesting, ([], {})),
129+
"OMPL STRIDE": (OMPL_STRIDE, BasicTesting, ([], {})),
130+
"OMPL PDST": (OMPL_PDST, BasicTesting, ([], {})),
131+
"OMPL SST": (OMPL_SST, BasicTesting, ([], {})),
132+
"OMPL BiEst": (OMPL_BiEST, BasicTesting, ([], {})),
133+
"OMPL TRRT": (OMPL_TRRT, BasicTesting, ([], {})),
134+
"OMPL RRTConnect": (OMPL_RRTConnect, BasicTesting, ([], {})),
135+
"OMPL BITstar": (OMPL_BITstar, BasicTesting, ([], {})),
136+
"OMPL BKPIECE1": (OMPL_BKPIECE1, BasicTesting, ([], {})),
137+
"OMPL EST": (OMPL_EST, BasicTesting, ([], {})),
138+
# "OMPL LTLPlanner": (OMPL_LTLPlanner, BasicTesting, ([], {})),
139+
# "OMPL AITstar": (OMPL_AITstar, BasicTesting, ([], {})),
140+
# "OMPL AnytimePathShortening": (OMPL_AnytimePathShortening, BasicTesting, ([], {})),
141+
# "OMPL BFMT": (OMPL_BFMT, BasicTesting, ([], {})),
142+
# "OMPL BiRLRT": (OMPL_BiRLRT, BasicTesting, ([], {})),
143+
# "OMPL BiTRRT": (OMPL_BiTRRT, BasicTesting, ([], {})),
144+
# "OMPL Syclop ": (OMPL_Syclop, BasicTesting, ([], {})),
145+
# "OMPL CForest": (OMPL_CForest, BasicTesting, ([], {})),
146+
# "OMPL FMT": (OMPL_FMT, BasicTesting, ([], {})),
147+
# "OMPL SPARS": (OMPL_SPARS, BasicTesting, ([], {})),
148+
# "OMPL SPARS2": (OMPL_SPARS2, BasicTesting, ([], {})),
149+
# "OMPL VFRRT": (OMPL_VFRRT, BasicTesting, ([], {})),
150+
# "OMPL pRRT": (OMPL_pRRT, BasicTesting, ([], {})),
151+
# "OMPL TSRRT": (OMPL_TSRRT, BasicTesting, ([], {})),
152+
# "OMPL pSBL": (OMPL_pSBL, BasicTesting, ([], {})),
153+
# "OMPL QRRT": (OMPL_QRRT, BasicTesting, ([], {})),
154+
})
155+
156+
@staticmethod
157+
def load_all(ids: List[str]) -> List[List[Tuple[str, MetaData]]]:
158+
"""
159+
Returns a list of algorithms from a list of names or file paths.
160+
161+
For each element in `ids`, if string is the display name
162+
of a built-in algorithm, then we return that algorithm. Otherwise,
163+
we return the result of AlgorithmManager.try_load_from_file().
164+
"""
165+
166+
algs: List[List[Tuple[str, MetaData]]] = []
167+
for alg in ids:
168+
if alg in AlgorithmManager.builtins:
169+
algs.append([copy.deepcopy((alg, AlgorithmManager.builtins[alg]))])
170+
else:
171+
algs.append(AlgorithmManager.try_load_from_file(alg))
172+
return algs
173+
174+
@staticmethod
175+
def try_load_from_file(path: str) -> List[Tuple[str, MetaData]]:
176+
if not os.path.exists(path):
177+
msg = "File '{}' does not exist".format(path)
178+
print(msg, file=sys.stderr)
179+
return []
180+
181+
try:
182+
spec = importlib.util.spec_from_file_location("custom_loaded", path)
183+
module = importlib.util.module_from_spec(spec)
184+
spec.loader.exec_module(module)
185+
186+
# return all classes that inherit from "Algorithm"
187+
algs = []
188+
for name in dir(module):
189+
if name.startswith("_"):
190+
continue
191+
192+
cls = getattr(module, name)
193+
if inspect.isclass(cls) and cls is not Algorithm and issubclass(cls, Algorithm):
194+
name = cls.name if "name" in cls.__dict__ else os.path.basename(path) + " ({})".format(name)
195+
testing = cls.testing if "testing" in cls.__dict__ else BasicTesting
196+
algs.append((name, (cls, testing, ([], {}))))
197+
return algs
198+
except:
199+
msg = "Failed to load algorithms from file '{}', reason:\n{}".format(path, traceback.format_exc())
200+
print(msg, file=sys.stderr)
201+
return []

src/algorithms/configuration/configuration.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Tuple, Callable, Type, List, Optional, Dict, Any, Union
2+
import copy
23

34
from algorithms.algorithm import Algorithm
45
from algorithms.basic_testing import BasicTesting
@@ -38,6 +39,10 @@ class Configuration:
3839
generator_show_gen_sample: bool
3940
generator_house_expo: bool
4041
generator_size: int
42+
generator_obstacle_fill_min: float
43+
generator_obstacle_fill_max: float
44+
generator_min_room_size: int
45+
generator_max_room_size: int
4146

4247
# Trainer
4348
trainer: bool
@@ -47,6 +52,7 @@ class Configuration:
4752

4853
# Misc
4954
analyzer: bool
55+
algorithms: Dict[str, Tuple[Type[Algorithm], Type[BasicTesting], Tuple[List[Any], Dict[str, Any]]]]
5056
load_simulator: bool
5157
clear_cache: bool
5258
num_dim: int
@@ -82,6 +88,10 @@ def __init__(self) -> None:
8288
self.generator_modify = None
8389
self.generator_show_gen_sample = False
8490
self.generator_house_expo = False
91+
self.generator_obstacle_fill_min = 0.1
92+
self.generator_obstacle_fill_max = 0.3
93+
self.generator_min_room_size = 3
94+
self.generator_max_room_size = 16
8595
self.generator_size = 64
8696

8797
self.num_dim = 2
@@ -93,9 +103,13 @@ def __init__(self) -> None:
93103
self.trainer_pre_process_data_only = False
94104
self.trainer_bypass_and_replace_pre_processed_cache = False
95105

96-
# Custom behaviour settings
106+
# Analyzer
97107
self.analyzer = False
98108

109+
# Common
110+
from algorithms.algorithm_manager import AlgorithmManager
111+
self.algorithms = copy.deepcopy(AlgorithmManager.builtins)
112+
99113
# Simulator
100114
self.load_simulator = False
101115

src/algorithms/lstm/LSTM_CAE_tile_by_tile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from algorithms.lstm.ML_model import MLModel, EvaluationResults
2020
from algorithms.lstm.map_processing import MapProcessing
2121
from simulator.services.services import Services
22-
from constants import DATA_PATH
22+
from utility.constants import DATA_PATH
2323

2424

2525
class CAEEncoder(nn.Module):

src/analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ def main():
1515
analyzer.analyze_algorithms()
1616

1717
if __name__== '__main__':
18-
main()
18+
main()

0 commit comments

Comments
 (0)