Skip to content

Commit 7eff87c

Browse files
committed
[BugFix] Soccer
1 parent acd9b7a commit 7eff87c

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

vmas/scenarios/football.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2233,16 +2233,27 @@ def get_pos_value(self, pos, agent, env_index=Ellipsis):
22332233
return value
22342234

22352235
def get_wall_separations(self, pos):
2236-
top_wall_dist = -pos[:, Y] + self.world.pitch_width / 2
2237-
bottom_wall_dist = pos[:, Y] + self.world.pitch_width / 2
2238-
left_wall_dist = pos[:, X] + self.world.pitch_length / 2
2239-
right_wall_dist = -pos[:, X] + self.world.pitch_length / 2
2236+
top_wall_dist = -pos[..., Y] + self.world.pitch_width / 2
2237+
bottom_wall_dist = pos[..., Y] + self.world.pitch_width / 2
2238+
left_wall_dist = pos[..., X] + self.world.pitch_length / 2
2239+
right_wall_dist = -pos[..., X] + self.world.pitch_length / 2
22402240
vertical_wall_disp = torch.zeros(pos.shape, device=self.world.device)
2241-
vertical_wall_disp[:, Y] = torch.minimum(top_wall_dist, bottom_wall_dist)
2242-
vertical_wall_disp[bottom_wall_dist < top_wall_dist, Y] *= -1
2241+
vertical_wall_disp[..., Y] = torch.minimum(top_wall_dist, bottom_wall_dist)
22432242
horizontal_wall_disp = torch.zeros(pos.shape, device=self.world.device)
2244-
horizontal_wall_disp[:, X] = torch.minimum(left_wall_dist, right_wall_dist)
2245-
horizontal_wall_disp[left_wall_dist < right_wall_dist, X] *= -1
2243+
horizontal_wall_disp[..., X] = torch.minimum(left_wall_dist, right_wall_dist)
2244+
2245+
shape = vertical_wall_disp.shape
2246+
vertical_wall_disp = vertical_wall_disp.view(shape[0] * shape[1], 2)
2247+
mask = (bottom_wall_dist < top_wall_dist).view(shape[0] * shape[1])
2248+
vertical_wall_disp[mask, Y] *= -1
2249+
vertical_wall_disp = vertical_wall_disp.view(*shape)
2250+
2251+
shape = horizontal_wall_disp.shape
2252+
horizontal_wall_disp = horizontal_wall_disp.view(shape[0] * shape[1], 2)
2253+
mask = (left_wall_dist < right_wall_dist).view(shape[0] * shape[1])
2254+
horizontal_wall_disp[mask, X] *= -1
2255+
horizontal_wall_disp = horizontal_wall_disp.view(*shape)
2256+
22462257
return torch.stack([vertical_wall_disp, horizontal_wall_disp], dim=-2)
22472258

22482259
def get_separations(

0 commit comments

Comments
 (0)