@@ -2233,16 +2233,27 @@ def get_pos_value(self, pos, agent, env_index=Ellipsis):
22332233 return value
22342234
22352235 def get_wall_separations (self , pos ):
2236- top_wall_dist = - pos [: , Y ] + self .world .pitch_width / 2
2237- bottom_wall_dist = pos [: , Y ] + self .world .pitch_width / 2
2238- left_wall_dist = pos [: , X ] + self .world .pitch_length / 2
2239- right_wall_dist = - pos [: , X ] + self .world .pitch_length / 2
2236+ top_wall_dist = - pos [... , Y ] + self .world .pitch_width / 2
2237+ bottom_wall_dist = pos [... , Y ] + self .world .pitch_width / 2
2238+ left_wall_dist = pos [... , X ] + self .world .pitch_length / 2
2239+ right_wall_dist = - pos [... , X ] + self .world .pitch_length / 2
22402240 vertical_wall_disp = torch .zeros (pos .shape , device = self .world .device )
2241- vertical_wall_disp [:, Y ] = torch .minimum (top_wall_dist , bottom_wall_dist )
2242- vertical_wall_disp [bottom_wall_dist < top_wall_dist , Y ] *= - 1
2241+ vertical_wall_disp [..., Y ] = torch .minimum (top_wall_dist , bottom_wall_dist )
22432242 horizontal_wall_disp = torch .zeros (pos .shape , device = self .world .device )
2244- horizontal_wall_disp [:, X ] = torch .minimum (left_wall_dist , right_wall_dist )
2245- horizontal_wall_disp [left_wall_dist < right_wall_dist , X ] *= - 1
2243+ horizontal_wall_disp [..., X ] = torch .minimum (left_wall_dist , right_wall_dist )
2244+
2245+ shape = vertical_wall_disp .shape
2246+ vertical_wall_disp = vertical_wall_disp .view (shape [0 ] * shape [1 ], 2 )
2247+ mask = (bottom_wall_dist < top_wall_dist ).view (shape [0 ] * shape [1 ])
2248+ vertical_wall_disp [mask , Y ] *= - 1
2249+ vertical_wall_disp = vertical_wall_disp .view (* shape )
2250+
2251+ shape = horizontal_wall_disp .shape
2252+ horizontal_wall_disp = horizontal_wall_disp .view (shape [0 ] * shape [1 ], 2 )
2253+ mask = (left_wall_dist < right_wall_dist ).view (shape [0 ] * shape [1 ])
2254+ horizontal_wall_disp [mask , X ] *= - 1
2255+ horizontal_wall_disp = horizontal_wall_disp .view (* shape )
2256+
22462257 return torch .stack ([vertical_wall_disp , horizontal_wall_disp ], dim = - 2 )
22472258
22482259 def get_separations (
0 commit comments