diff --git a/vmas/scenarios/transport.py b/vmas/scenarios/transport.py index db749b38..3850b0e9 100644 --- a/vmas/scenarios/transport.py +++ b/vmas/scenarios/transport.py @@ -449,7 +449,7 @@ def partial_observation(self, agent: Agent): package_obs.append(self.og_package_positions[i]) package_obs.append(package.on_goal.unsqueeze(-1)) - mask = (torch.linalg.vector_norm(package.state.pos - agent.state.pos, dim=-1) < self.package_observation_radius) + mask = (torch.linalg.vector_norm(package.state.pos - agent.state.pos, dim=-1) < (self.package_observation_radius + (self.package_length/2.0))) pkg_state_vec = package.state.pos.clone() pkg_rot_vec = package.state.rot.clone() pkg_vel_vec = package.state.vel.clone()