diff --git a/vmas/scenarios/transport.py b/vmas/scenarios/transport.py index a264a744..b781f784 100644 --- a/vmas/scenarios/transport.py +++ b/vmas/scenarios/transport.py @@ -81,8 +81,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.package_mass = kwargs.get("package_mass", 3) # partial obs - self.partial_observations = kwargs.get("partial_observations", False) - self.package_observation_radius = kwargs.get("package_observation_radius", 0.35) + self.partial_observations = kwargs.get("partial_observations", True) + self.package_observation_dist = kwargs.get("package_observation_dist", 0.35) # realism self.linear_friction = kwargs.get("linear_friction", 0.01) @@ -131,6 +131,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): # Add agents capabilities = [] # save capabilities for relative capabilities later + self.observation_sensors = [] # for partial observability for i in range(n_agents): max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) @@ -152,6 +153,20 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): world.add_agent(agent) + # add the observation sensor if partial observability is turned on. + if self.partial_observations: + self.observation_sensors.append( + Landmark( + name=f'obs_sensor_agent_{i}', + collide=False, + shape=Sphere(radius=self.package_observation_dist+radius), + color=(0.827, 0.827, 0.827, 0.65), + movable=False, + ) + ) + world.add_landmark(self.observation_sensors[-1]) + + self.capabilities = torch.tensor(capabilities) # Add landmarks @@ -191,7 +206,7 @@ def reset_world_at(self, env_index: int = None): # only do this during batched resets! if not env_index: capabilities = [] # save capabilities for relative capabilities later - for agent in self.world.agents: + for i, agent in enumerate(self.world.agents): max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max) @@ -203,6 +218,12 @@ def reset_world_at(self, env_index: int = None): agent.shape=Sphere(radius) agent.mass=mass + # spawn the sensor radius for each agent + if self.partial_observations: + self.observation_sensors[i].set_pos(self.world.agents[i].state.pos, env_index) + self.observation_sensors[i].shape = Sphere(self.package_observation_dist+radius) + + self.capabilities = torch.tensor(capabilities) # spawn goal at origin @@ -256,7 +277,7 @@ def reset_world_at(self, env_index: int = None): ), occupied_positions=package_occupied_pos, ) - + self.package_starting_dists = [] self.og_package_positions = [] for i, package in enumerate(self.packages): @@ -444,12 +465,17 @@ def partial_observation(self, agent: Agent): # get positions of all entities in this agent's reference frame package_obs = [] out_of_obs_val = -0.0001 # default value used for out-of-observation data in the observation vector + + # spawn the sensor radius for each agent + for i, agent_i_sensor in enumerate(self.observation_sensors): + agent_i_sensor.set_pos(self.world.agents[i].state.pos, None) + for i, package in enumerate(self.packages): # box starting position and goal position alway part of the observation package_obs.append(self.og_package_positions[i]) package_obs.append(package.on_goal.unsqueeze(-1)) - mask = (torch.linalg.vector_norm(package.state.pos - agent.state.pos, dim=-1) < self.package_observation_radius) + mask = self.world.is_overlapping(self.observation_sensors[i], package) pkg_state_vec = package.state.pos.clone() pkg_rot_vec = package.state.rot.clone() pkg_vel_vec = package.state.vel.clone() @@ -606,15 +632,6 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]": geoms: List[Geom] = [] if not self.partial_observations: return geoms - - for i, agent in enumerate(self.world.agents): - - obs_circle = rendering.make_circle(self.package_observation_radius, filled=True) - xform = rendering.Transform() - xform.set_translation(*agent.state.pos[env_index]) - obs_circle.add_attr(xform) - obs_circle.set_color(*(0.827, 0.827, 0.827, 0.65)) - geoms.append(obs_circle) return geoms