Skip to content

Commit 132d97b

Browse files
[Feature] Boundary visualization for limited-size environments (#142)
* add boundary visualization for limited-size environments * add boundary visualization for limited-size environments * add boundary visualization for limited-size environments * introduce "visualize_semidims" to display boundaries * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/scenario.py Co-authored-by: Matteo Bettini <[email protected]> * add boundary visualization for limited-size environments * disabled "visualize_semidims" as boundaries are already being plotted in this scenario * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * disabled "visualize_semidims" as boundaries are already being plotted in this scenario * add boundary visualization for limited-size environments * Update vmas/simulator/environment/environment.py Co-authored-by: Matteo Bettini <[email protected]> * add boundary visualization for limited-size environments --------- Co-authored-by: Matteo Bettini <[email protected]>
1 parent 26ceb42 commit 132d97b

File tree

11 files changed

+77
-0
lines changed

11 files changed

+77
-0
lines changed

vmas/scenarios/balance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2626
self.shaping_factor = 100
2727
self.fall_reward = -10
2828

29+
self.visualize_semidims = False
30+
2931
# Make world
3032
world = World(batch_dim, device, gravity=(0.0, -0.05), y_semidim=1)
3133
# Add agents

vmas/scenarios/ball_passage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
3333
self.passage_width = 0.2
3434
self.passage_length = 0.103
3535

36+
self.visualize_semidims = False
37+
3638
# Make world
3739
world = World(
3840
batch_dim,

vmas/scenarios/football.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
class Scenario(BaseScenario):
1818
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
1919
self.init_params(**kwargs)
20+
self.visualize_semidims = False
2021
world = self.init_world(batch_dim, device)
2122
self.init_agents(world)
2223
self.init_ball(world)

vmas/scenarios/joint_passage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6565
ScenarioUtils.check_kwargs_consumed(kwargs)
6666

6767
self.plot_grid = True
68+
self.visualize_semidims = False
69+
6870
# Make world
6971
world = World(
7072
batch_dim,

vmas/scenarios/joint_passage_size.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7373
assert self.n_passages == 3 or self.n_passages == 4
7474

7575
self.plot_grid = False
76+
self.visualize_semidims = False
7677

7778
# Make world
7879
world = World(

vmas/scenarios/mpe/simple_tag.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2626
self.respawn_at_catch = kwargs.pop("respawn_at_catch", False)
2727
ScenarioUtils.check_kwargs_consumed(kwargs)
2828

29+
self.visualize_semidims = False
30+
2931
world = World(
3032
batch_dim=batch_dim,
3133
device=device,

vmas/scenarios/passage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2626
self.passage_width = 0.2
2727
self.passage_length = 0.103
2828

29+
self.visualize_semidims = False
30+
2931
# Make world
3032
world = World(batch_dim, device, x_semidim=1, y_semidim=1)
3133
# Add agents

vmas/scenarios/road_traffic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Scenario(BaseScenario):
5757

5858
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5959
self.init_params(batch_dim, device, **kwargs)
60+
self.visualize_semidims = False
6061
world = self.init_world(batch_dim, device)
6162
self.init_agents(world)
6263
return world

vmas/scenarios/sampling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
4343
assert len(self.covs) == self.n_gaussians
4444

4545
self.plot_grid = False
46+
self.visualize_semidims = False
4647
self.n_x_cells = int((2 * self.xdim) / self.grid_spacing)
4748
self.n_y_cells = int((2 * self.ydim) / self.grid_spacing)
4849
self.max_pdf = torch.zeros((batch_dim,), device=device, dtype=torch.float32)

vmas/simulator/environment/environment.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,9 @@ def render(
741741
)
742742

743743
# Render
744+
if self.scenario.visualize_semidims:
745+
self.plot_boundary()
746+
744747
self._set_agent_comm_messages(env_index)
745748

746749
if plot_position_function is not None:
@@ -770,6 +773,64 @@ def render(
770773
# render to display or array
771774
return self.viewer.render(return_rgb_array=mode == "rgb_array")
772775

776+
def plot_boundary(self):
777+
# include boundaries in the rendering if the environment is dimension-limited
778+
if self.world.x_semidim is not None or self.world.y_semidim is not None:
779+
from vmas.simulator.rendering import Line
780+
from vmas.simulator.utils import Color
781+
782+
# set a big value for the cases where the environment is dimension-limited only in one coordinate
783+
infinite_value = 100
784+
785+
x_semi = (
786+
self.world.x_semidim
787+
if self.world.x_semidim is not None
788+
else infinite_value
789+
)
790+
y_semi = (
791+
self.world.y_semidim
792+
if self.world.y_semidim is not None
793+
else infinite_value
794+
)
795+
796+
# set the color for the boundary line
797+
color = Color.GRAY.value
798+
799+
# Define boundary points based on whether world semidims are provided
800+
if (
801+
self.world.x_semidim is not None and self.world.y_semidim is not None
802+
) or self.world.y_semidim is not None:
803+
boundary_points = [
804+
(-x_semi, y_semi),
805+
(x_semi, y_semi),
806+
(x_semi, -y_semi),
807+
(-x_semi, -y_semi),
808+
]
809+
else:
810+
boundary_points = [
811+
(-x_semi, y_semi),
812+
(-x_semi, -y_semi),
813+
(x_semi, y_semi),
814+
(x_semi, -y_semi),
815+
]
816+
817+
# Create lines by connecting points
818+
for i in range(
819+
0,
820+
len(boundary_points),
821+
1
822+
if (
823+
self.world.x_semidim is not None
824+
and self.world.y_semidim is not None
825+
)
826+
else 2,
827+
):
828+
start = boundary_points[i]
829+
end = boundary_points[(i + 1) % len(boundary_points)]
830+
line = Line(start, end, width=0.7)
831+
line.set_color(*color)
832+
self.viewer.add_onetime(line)
833+
773834
def plot_function(
774835
self, f, precision, plot_range, cmap_range, cmap_alpha, cmap_name
775836
):

0 commit comments

Comments
 (0)