Skip to content

Commit d4a4a90

Browse files
MapManager & command-line interface analogous to algorithm's one (#61)
* Added MapManager and switched to using that rather than Maps * Reworked map management Signed-off-by: [ 大鳄 ] Asew <[email protected]> * Checking for name when updating map Signed-off-by: [ 大鳄 ] Asew <[email protected]> Co-authored-by: [ 大鳄 ] Asew <[email protected]>
1 parent 0993e03 commit d4a4a90

File tree

22 files changed

+538
-393
lines changed

22 files changed

+538
-393
lines changed

src/algorithms/algorithm_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from algorithms.algorithm import Algorithm
22
from utility.compatibility import HAS_OMPL
3+
from utility.misc import static_class
34

45
from typing import Optional, List, Type, Dict, Any, Tuple
56
import importlib.util
@@ -75,11 +76,6 @@
7576
from algorithms.classic.sample_based.ompl_stride import OMPL_STRIDE
7677
from algorithms.classic.sample_based.ompl_qrrt import OMPL_QRRT
7778

78-
def static_class(cls):
79-
if getattr(cls, "_static_init_", None):
80-
cls._static_init_()
81-
return cls
82-
8379
@static_class
8480
class AlgorithmManager():
8581
MetaData = Tuple[Type[Algorithm], Type[BasicTesting], Tuple[List[Any], Dict[str, Any]]]

src/algorithms/configuration/configuration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def __init__(self) -> None:
112112
from algorithms.algorithm_manager import AlgorithmManager
113113
self.algorithms = copy.deepcopy(AlgorithmManager.builtins)
114114

115+
from maps.map_manager import MapManager
116+
self.maps = copy.deepcopy(MapManager.builtins)
117+
115118
self.map_name = None
116119
self.algorithm_name = None
117120

src/algorithms/configuration/maps/dense_map.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class DenseMap(Map):
2828
# The transpose flag is set to true as a default for the initialization, as that is
2929
# how we are storing internally, but it will be set to false when we are simply translating from sparsemap
3030
# When we create more map views, we should set it to false
31-
def __init__(self, grid: Optional[List] = None, services: Services = None, transpose: bool = True, mutable: bool = False) -> None:
31+
def __init__(self, grid: Optional[List] = None, services: Services = None, transpose: bool = True, mutable: bool = False, name: Optional[str] = None) -> None:
3232
self.grid = None
3333

3434
arr_grid = None
@@ -37,9 +37,9 @@ def __init__(self, grid: Optional[List] = None, services: Services = None, trans
3737
if arr_grid.dtype == object:
3838
raise ValueError("Cannot create DenseMap grid from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes)")
3939

40-
super().__init__(Size(*([0]*arr_grid.ndim)), services, mutable)
40+
super().__init__(Size(*([0]*arr_grid.ndim)), services, mutable, name)
4141
else:
42-
super().__init__(services=services, mutable=mutable)
42+
super().__init__(services=services, mutable=mutable, name=name)
4343
return
4444

4545
# Doesn't work with non-uniform grids
@@ -186,6 +186,7 @@ def __copy__(self) -> 'DenseMap':
186186

187187
def __deepcopy__(self, memo: Dict) -> 'DenseMap':
188188
dense_map = self.__class__(copy.deepcopy(self.grid), services=self.services, transpose=False)
189+
dense_map.name = copy.deepcopy(self.name)
189190
dense_map.trace = copy.deepcopy(self.trace)
190191
dense_map.agent = copy.deepcopy(self.agent)
191192
dense_map.goal = copy.deepcopy(self.goal)

src/algorithms/configuration/maps/map.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Map:
3030
EXTENDED_WALL_ID: int = 4
3131
UNMAPPED_ID: int = 5
3232

33+
name: Optional[str]
3334
agent: Agent
3435
goal: Goal
3536
obstacles: List[Obstacle]
@@ -93,7 +94,7 @@ def init_direction_vectors(self):
9394
self.DIRECT_POINTS_MOVE_VECTOR: List[Point] = \
9495
list(map(lambda x: Point(*x), ALL_DIRECT_POINTS_DIMENSIONS))
9596

96-
def __init__(self, size: Size = None, services: Services = None, mutable: bool = False) -> None:
97+
def __init__(self, size: Size = None, services: Services = None, mutable: bool = False, name: Optional[str] = None) -> None:
9798
"""
9899
:param size: The map size
99100
:param services: The simulator services
@@ -106,6 +107,7 @@ def __init__(self, size: Size = None, services: Services = None, mutable: bool =
106107
self._size = None
107108
self.size = size
108109
self.__mutable = mutable
110+
self.name = name
109111

110112
def get_obstacle_bound(self, obstacle_start_point: Point, visited: Optional[Set[Point]] = None) -> Set[Point]:
111113
"""

src/algorithms/configuration/maps/occupancy_grid_map.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ def __init__(self,
3737
traversable_threshold: Optional[Real] = None,
3838
unmapped_value: Optional[Real] = None,
3939
services: Services = None,
40-
mutable: bool = True) -> None:
41-
super().__init__(services=services, mutable=mutable)
40+
mutable: bool = True,
41+
name: Optional[str] = None) -> None:
42+
super().__init__(services=services, mutable=mutable, name=name)
4243
self.agent = agent
4344
self.goal = goal
4445
self.weight_grid = None
@@ -200,7 +201,8 @@ def __deepcopy__(self, memo: Dict) -> 'OccupancyGridMap':
200201
mp = self.__class__(agent=copy.deepcopy(self.agent),
201202
goal=copy.deepcopy(self.goal),
202203
services=self.services,
203-
mutable=self.mutable)
204+
mutable=self.mutable,
205+
name=copy.deepcopy(self.name))
204206
mp.size = copy.deepcopy(self.size)
205207
mp.traversable_threshold = copy.deepcopy(self.traversable_threshold)
206208
mp.weight_grid = copy.deepcopy(self.weight_grid)

src/algorithms/configuration/maps/ros_map.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ def __init__(self, agent: Agent, goal: Goal,
2424
wp_publish: Optional[Callable[[Point], None]] = None,
2525
update_requested: Optional[Callable[[], None]] = None,
2626
services: Services = None,
27-
mutable: bool = True) -> None:
28-
super().__init__(agent=agent, goal=goal, services=services, mutable=mutable)
27+
mutable: bool = True,
28+
name: Optional[str] = None) -> None:
29+
super().__init__(agent=agent, goal=goal, services=services, mutable=mutable, name=name)
2930

3031
self.__get_grid = get_grid
3132
self.__wp_publish = wp_publish
@@ -53,7 +54,7 @@ def __deepcopy__(self, memo: Dict) -> 'RosMap':
5354
mp = RosMap(copy.deepcopy(self.agent), copy.deepcopy(self.goal),
5455
self.__get_grid, self.__weight_bounds, self.__traversable_threshold, self.__unmapped_value,
5556
self.__wp_publish, self.__update_requested,
56-
self.services, self.mutable)
57+
self.services, self.mutable, copy.deepcopy(self.name))
5758
mp.size = copy.deepcopy(self.size)
5859
mp.weight_grid = copy.deepcopy(self.weight_grid)
5960
mp.traversable_threshold = copy.deepcopy(self.traversable_threshold)

src/algorithms/configuration/maps/sparse_map.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ class SparseMap(Map):
2222

2323
DIST_TOLERANCE = 0.0001
2424

25-
def __init__(self, size: Size, agent: Agent, obstacles: List[Obstacle], goal: Goal, services: Services = None, mutable: bool = False) \
25+
def __init__(self, size: Size, agent: Agent, obstacles: List[Obstacle], goal: Goal, services: Services = None, mutable: bool = False, name: Optional[str] = None) \
2626
-> None:
27-
super().__init__(size, services, mutable)
27+
super().__init__(size, services, mutable, name)
2828
self.agent = agent
2929
self.obstacles = obstacles
3030
self.goal = goal
@@ -118,8 +118,13 @@ def __copy__(self) -> 'SparseMap':
118118
return copy.deepcopy(self)
119119

120120
def __deepcopy__(self, memo: Dict) -> 'SparseMap':
121-
dense_map = SparseMap(self.size, copy.deepcopy(self.agent),
122-
copy.deepcopy(self.obstacles), copy.deepcopy(self.goal), self.services)
121+
dense_map = SparseMap(self.size,
122+
copy.deepcopy(self.agent),
123+
copy.deepcopy(self.obstacles),
124+
copy.deepcopy(self.goal),
125+
self.services,
126+
self.mutable,
127+
copy.deepcopy(self.name))
123128
dense_map.trace = copy.deepcopy(self.trace)
124129
return dense_map
125130

src/analyzer/analyzer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from algorithms.basic_testing import BasicTesting
1212
from algorithms.algorithm import Algorithm
1313
from structures import Point
14-
from maps import Maps
1514

1615
from io import StringIO
1716
import seaborn as sns

src/main.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from algorithms.configuration.configuration import Configuration
22
from algorithms.algorithm_manager import AlgorithmManager
3+
from maps.map_manager import MapManager
34
from algorithms.lstm.trainer import Trainer
45
from analyzer.analyzer import Analyzer
56
from generator.generator import Generator
@@ -151,7 +152,7 @@ def configure_common(config, args) -> bool:
151152
print("Available algorithms:")
152153
for key in AlgorithmManager.builtins.keys():
153154
print(f" {key}")
154-
print("Or specify your own file with a class that inherits from Algorithm")
155+
print("Or specify your own file that contains a class that inherits from Algorithm")
155156
sys.exit(0)
156157

157158
if args.algorithms:
@@ -162,7 +163,7 @@ def configure_common(config, args) -> bool:
162163
valid_str = ",".join('"' + a + '"' for a in AlgorithmManager.builtins.keys())
163164
print(f"Invalid algorithm(s) specified: {invalid_str}", file=sys.stderr)
164165
print(f"Available algorithms: {valid_str}", file=sys.stderr)
165-
print("Or specify your own file with a class that inherits from Algorithm", file=sys.stderr)
166+
print("Or specify your own file that contains a class that inherits from Algorithm", file=sys.stderr)
166167
return False
167168

168169
algorithms = list(flatten(algorithms, depth=1))
@@ -179,6 +180,46 @@ def configure_common(config, args) -> bool:
179180

180181
config.algorithms = algorithms
181182

183+
if args.list_maps:
184+
print("Available maps:")
185+
for key in MapManager.builtins.keys():
186+
print(f" {key}")
187+
print("Can also specify a custom map,")
188+
print(" (1) cached map stored in Maps")
189+
print(" (2) external file that contains a global variable with type that inherits from Map")
190+
sys.exit(0)
191+
192+
if args.maps:
193+
maps = MapManager.load_all(args.maps)
194+
if not all(maps):
195+
invalid_maps = [args.maps[i] for i in range(len(maps)) if not maps[i]]
196+
invalid_str = ",".join('"' + a + '"' for a in invalid_maps)
197+
valid_str = ",".join('"' + a + '"' for a in MapManager.builtins.keys())
198+
print(f"Invalid map(s) specified: {invalid_str}", file=sys.stderr)
199+
print(f"Available maps: {valid_str}", file=sys.stderr)
200+
print("Can also specify a custom map,", file=sys.stderr)
201+
print(" (1) cached map stored in Maps", file=sys.stderr)
202+
print(" (2) external file that contains a global variable with type that inherits from Map", file=sys.stderr)
203+
return False
204+
205+
maps = list(flatten(maps, depth=1))
206+
207+
# name uniqueness
208+
names = [a[0] for a in maps]
209+
if len(set(names)) != len(names):
210+
print("Name conflict detected in custom map list:", names, file=sys.stderr)
211+
return False
212+
213+
maps = dict(maps)
214+
if args.include_default_builtin_maps or args.include_all_builtin_maps:
215+
maps.update(MapManager.builtins)
216+
if args.include_all_builtin_maps:
217+
maps.update(MapManager.cached_builtins)
218+
219+
config.maps = maps
220+
elif args.include_all_builtin_maps:
221+
config.maps.update(MapManager.cached_builtins)
222+
182223
if args.deterministic:
183224
random.seed(args.std_random_seed)
184225
torch.manual_seed(args.torch_random_seed)
@@ -225,9 +266,17 @@ def main() -> bool:
225266
parser.add_argument("--dims", type=int, help="[generator|analyzer] number of dimensions", default=3)
226267

227268
parser.add_argument("--algorithms", help="[visualiser|analyzer] algorithms to load (either built-in algorithm name or module file path)", nargs="+")
228-
parser.add_argument("--include-builtin-algorithms", action='store_true', help="include all builtin algorithms even when a custom list is provided via '--algorithms'")
269+
parser.add_argument("--include-builtin-algorithms", action='store_true',
270+
help="[visualiser|analyzer] include all builtin algorithms even when a custom list is provided via '--algorithms'")
229271
parser.add_argument("--list-algorithms", action="store_true", help="[visualiser|analyzer] output list of available built-in algorithms")
230272

273+
parser.add_argument("--maps", help="[visualiser|analyzer|trainer] maps to load (either built-in map name or module file path)", nargs="+")
274+
parser.add_argument("--include-all-builtin-maps", action='store_true',
275+
help="[visualiser|analyzer|trainer] include all builtin maps (includes all cached maps) even when a custom list is provided via '--maps'")
276+
parser.add_argument("--include-default-builtin-maps", action='store_true',
277+
help="[visualiser|analyzer|trainer] include default builtin maps (does not include all cached maps) even when a custom list is provided via '--maps'")
278+
parser.add_argument("--list-maps", action="store_true", help="[visualiser|analyzer|trainer] output list of available built-in maps")
279+
231280
parser.add_argument("--deterministic", action='store_true', help="use pre-defined random seeds for deterministic exeuction")
232281
parser.add_argument("--std-random-seed", type=int, default=0, help="'random' module random number generator seed")
233282
parser.add_argument("--numpy-random-seed", type=int, default=0, help="'numpy' module random number generator seed")

0 commit comments

Comments
 (0)