diff --git a/vmas/scenarios/transport.py b/vmas/scenarios/transport.py
index df60fce7..8e506bee 100644
--- a/vmas/scenarios/transport.py
+++ b/vmas/scenarios/transport.py
@@ -40,6 +40,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         self.package_length = kwargs.get("package_length", 0.4)
         self.package_rotatable = kwargs.get("package_rotatable", True)
         self.package_mass = kwargs.get("package_mass", 3)
+        # minimum distance the packages can spawn from the goal
+        self.min_pkg_goal_spawn_dist = kwargs.get("min_pkg_goal_spawn_dist", 0.01)
 
         # partial obs
         self.partial_observations = kwargs.get("partial_observations", False)
@@ -52,15 +54,16 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         # TODO: implement automated domain randomization here?
 
         # rewards
-        self.agent_package_dist_reward_factor = kwargs.get("agent_package_dist_reward_factor", 0.1)
-        self.package_goal_dist_reward_factor = kwargs.get("package_goal_dist_reward_factor", 100)
+        self.agent_package_dist_reward_factor = kwargs.get("agent_package_dist_reward_factor", 0)
+        self.package_goal_dist_reward_factor = kwargs.get("package_goal_dist_reward_factor", 0)
+        self.agent_near_pkg_rew_factor = kwargs.get("agent_near_pkg_rew_factor", 0)
 
         self.min_collision_distance = 0.05 * self.default_agent_radius  # default navigation collision dist is 5% of the agent radius
-        self.interagent_collision_penalty = kwargs.get("interagent_collision_penalty", -1)
+        self.interagent_collision_penalty = kwargs.get("interagent_collision_penalty", 0)
         assert self.interagent_collision_penalty <= 0, f"self.interagent_collision_penalty must be <= 0, current value is {self.interagent_collision_penalty}!"
 
         self.add_dense_reward = kwargs.get("add_dense_reward", True)
-        self.package_on_goal_reward_factor = kwargs.get("package_on_goal_reward_factor", 1.0)
+        self.package_on_goal_reward_factor = kwargs.get("package_on_goal_reward_factor", 0.0)
         self.agent_touching_package_reward_factor = kwargs.get("agent_touching_package_reward_factor", 0.0)
         self.time_penalty = kwargs.get("time_penalty", 0.0)
 
@@ -168,7 +171,7 @@ def reset_world_at(self, env_index: int = None):
             self.world,
             env_index,
             min_dist_between_entities=max(
-                package.shape.circumscribed_radius() + goal.shape.radius + 0.01
+                package.shape.circumscribed_radius() + goal.shape.radius + self.min_pkg_goal_spawn_dist
                 for package in self.packages
             ),
             x_bounds=(
@@ -253,17 +256,18 @@ def reward(self, agent: Agent):
                     Color.GREEN.value, device=self.world.device, dtype=torch.float32
                 )
 
-            # dense reward
             if self.add_dense_reward:
+                # reward for pushing the package closer to the goal than at the previous step
                 package_shaping = package.dist_to_goal * self.package_goal_dist_reward_factor
                 self.rew[~package.on_goal] += (
                     package.global_shaping[~package.on_goal] - package_shaping[~package.on_goal]
                 )
+                # "global shaping" caches the previous step's dist_to_goal * package_goal_dist_reward_factor
                 package.global_shaping = package_shaping
 
             # positive reward when the agent achieves the goal
-            self.rew[package.on_goal] += 1.0 * self.package_on_goal_reward_factor
+            # self.rew[package.on_goal] += 1.0 * self.package_on_goal_reward_factor
 
             _time_penalty += self.time_penalty
 
         # penalty (negative rew) for agent-agent collisions
@@ -283,14 +287,19 @@ def reward(self, agent: Agent):
                     distance <= self.min_collision_distance
                 ] += self.interagent_collision_penalty
 
-        # reward for how close agents are to all packages
+        # reward agents for being near a package
        if self.add_dense_reward:
             for i, package in enumerate(self.packages):
                 dist_to_pkg = torch.linalg.vector_norm(agent.state.pos - package.state.pos, dim=-1)
-                agent_touching_package=self.world.is_overlapping(package, agent)
-                self.rew += (-dist_to_pkg * self.agent_package_dist_reward_factor) + self.agent_touching_package_reward_factor * agent_touching_package
+                agent_diameter = torch.ones(dist_to_pkg.shape, device=self.world.device) * (agent.shape.radius * 2)
+                near_pkg = dist_to_pkg < 1.5 * agent_diameter
+                self.rew[near_pkg] += 1.0 * self.agent_near_pkg_rew_factor
+                # self.rew[~near_pkg] -= 1.0 * self.agent_near_pkg_rew_factor
 
-        return self.rew + agent.agent_collision_rew + _time_penalty
+                # agent_touching_package=self.world.is_overlapping(package, agent)
+                # self.rew += (-dist_to_pkg * self.agent_package_dist_reward_factor) + self.agent_touching_package_reward_factor * agent_touching_package
+
+        return self.rew + agent.agent_collision_rew  # + _time_penalty
 
     def info(self, agent: Agent):
         """
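
For reviewers unfamiliar with the shaping trick kept in reward(): the dense term pays out the per-step decrease of dist_to_goal * package_goal_dist_reward_factor, with package.global_shaping caching the previous step's value. A minimal standalone sketch of that arithmetic (illustrative tensors, not the scenario code):

import torch

# Illustrative values only: dist_to_goal for a batch of 4 envs at t-1 and t.
package_goal_dist_reward_factor = 100.0
dist_prev = torch.tensor([1.0, 0.8, 0.5, 0.5])
dist_now = torch.tensor([0.9, 0.8, 0.6, 0.5])

global_shaping = dist_prev * package_goal_dist_reward_factor   # cached from last step
package_shaping = dist_now * package_goal_dist_reward_factor   # current step

# Positive when the package moved toward the goal, negative when it moved away.
rew = global_shaping - package_shaping
print(rew)  # tensor([ 10.,   0., -10.,   0.])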
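
Since the diff zeroes most reward defaults, the shaping terms presumably get re-enabled per experiment via kwargs. Assuming this fork keeps upstream VMAS's convention of forwarding scenario kwargs through vmas.make_env to make_world (worth verifying for this repo), a launch config could look like the sketch below; all factor values are illustrative, not recommended defaults:

import vmas

env = vmas.make_env(
    scenario="transport",
    num_envs=32,
    device="cpu",
    # new in this diff: minimum package-goal separation at spawn
    min_pkg_goal_spawn_dist=0.05,
    # new in this diff: flat bonus while an agent is within 1.5 agent diameters of a package
    agent_near_pkg_rew_factor=0.1,
    # shaping factors now default to 0; re-enable goal-distance shaping explicitly
    package_goal_dist_reward_factor=100,
    add_dense_reward=True,
)
obs = env.reset()

Zeroed defaults plus explicit kwargs keep reward ablations in the launch config rather than in the scenario file, which appears to be the motivation for this change.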