Skip to content

Commit f8c157e

Browse files
committed
Edited _check_and_handle_wall_collisions to prevent escape near corners
1 parent 5ce2a27 commit f8c157e

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

ratinabox/Agent.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -423,22 +423,23 @@ def _wall_velocity_update(self, **kwargs):
423423
def _check_and_handle_wall_collisions(self):
424424
"""This function checks to see if the vector from self.prev_pos to self.pos collides with any walls. If it does, then you've nothing to worry about. If it does, then you need to bounce off the wall and update the velocity and position accordingly. This is done in the handle_wall_collisions() function.
425425
TODO strictly wall collisions are only considered in 2D but this function should be extended to 1D too, for completeness."""
426-
proposed_step = np.array([self.prev_pos, self.pos])
427-
wall_check = self.Environment.check_wall_collisions(proposed_step) #returns (None, None) for 1D Envs
428-
walls = wall_check[0] # shape=(N_walls,2,2)
429-
wall_collisions = wall_check[1] # shape=(N_walls,)
430-
431-
# If no wall collsions it is safe to move to the next position so do nothing
432-
if (wall_collisions is None) or (True not in wall_collisions): return
433-
434-
# Bounce off walls you collide with
435-
elif True in wall_collisions:
436-
colliding_wall = walls[np.argwhere(wall_collisions == True)[0][0]]
437-
self.velocity = utils.wall_bounce(self.velocity, colliding_wall)
438-
self.velocity = (0.5 * self.speed_mean / (np.linalg.norm(self.velocity))) * self.velocity
439-
# TODO strictly in the event of a collision the position should be updated away from the wall starting from the collision point (and only for the remaining fraction of dt), not the prev position. Small detail but worth fixing.
440-
self.pos = self.prev_pos + self.velocity * self.dt
441-
return
426+
while True: # keep checking until no wall collisions
427+
proposed_step = np.array([self.prev_pos, self.pos])
428+
wall_check = self.Environment.check_wall_collisions(proposed_step) #returns (None, None) for 1D Envs
429+
walls = wall_check[0] # shape=(N_walls,2,2)
430+
wall_collisions = wall_check[1] # shape=(N_walls,)
431+
432+
# If no wall collsions it is safe to move to the next position so do nothing
433+
if (wall_collisions is None) or (True not in wall_collisions): return
434+
435+
# Bounce off walls you collide with
436+
elif True in wall_collisions:
437+
colliding_wall = walls[np.argwhere(wall_collisions == True)[0][0]]
438+
self.velocity = utils.wall_bounce(self.velocity, colliding_wall)
439+
self.velocity = (0.5 * self.speed_mean / (np.linalg.norm(self.velocity))) * self.velocity
440+
# TODO strictly in the event of a collision the position should be updated away from the wall starting from the collision point (and only for the remaining fraction of dt), not the prev position. Small detail but worth fixing.
441+
self.pos = self.prev_pos + self.velocity * self.dt
442+
442443

443444
def _measure_velocity_of_step_taken(self, overwrite_velocity=False):
444445
"""This function takes self.prev_pos and self.pos and uses them to update self.measured_velocity. Then it takes self.prev_measured_velocity and self.measured_velocity and calculates self.measured_rotational_velocity. These "measured" velocities are typically the same as self.velocity and self.rotational_velocity but not always. The reason for this is that when the Agent is near a wall it is possible for the dynamical updates to adjust its position without adjusting its velocity (e.g. conveyor belt drift), in which case the absolute velocities of the agent (which are the one we want to save into the history dataframe) may be subtely different from the velocity used in the motion updates thinks it has (self.velocity) and which it will use for dynamical updates on subsequent steps.

0 commit comments

Comments
 (0)