
Commit 39f5428

Added AlberDice RWARE.
1 parent e6df2dd commit 39f5428

File tree

7 files changed, +497 −0 lines changed

Lines changed: 68 additions & 0 deletions
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


class MultiAgentEnv(object):

    def step(self, actions):
        """Returns reward, terminated, info."""
        raise NotImplementedError

    def get_obs(self):
        """Returns all agent observations in a list."""
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id."""
        raise NotImplementedError

    def get_obs_size(self):
        """Returns the size of the observation."""
        raise NotImplementedError

    def get_state(self):
        """Returns the global state."""
        raise NotImplementedError

    def get_state_size(self):
        """Returns the size of the global state."""
        raise NotImplementedError

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        raise NotImplementedError

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        raise NotImplementedError

    def reset(self):
        """Returns initial observations and states."""
        raise NotImplementedError

    def render(self):
        """Renders the environment (optional)."""
        raise NotImplementedError

    def close(self):
        """Releases any resources held by the environment."""
        raise NotImplementedError

    def seed(self):
        """Seeds the environment's random number generator."""
        raise NotImplementedError

    def save_replay(self):
        """Save a replay."""
        raise NotImplementedError

    def get_env_info(self):
        """Returns a dict summarising the environment's dimensions."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "unit_dim": self.unit_dim}
        return env_info
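
To make the interface concrete, here is a minimal sketch of a subclass; DummyEnv and everything inside it are illustrative and do not appear in this commit. It fills in just enough state (n_agents, episode_limit, unit_dim) for the inherited get_env_info to work.

import numpy as np

# Hypothetical trivial environment, shown only to illustrate the contract.
class DummyEnv(MultiAgentEnv):

    def __init__(self, n_agents=2, episode_limit=50):
        self.n_agents = n_agents
        self.episode_limit = episode_limit
        self.unit_dim = 4                  # per-agent feature size (arbitrary)
        self._t = 0

    def reset(self):
        self._t = 0
        return self.get_obs(), self.get_state()

    def step(self, actions):
        self._t += 1
        return 0.0, self._t >= self.episode_limit, {}

    def get_obs(self):
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        return np.zeros(self.get_obs_size(), dtype=np.float32)

    def get_obs_size(self):
        return self.unit_dim

    def get_state(self):
        return np.concatenate(self.get_obs())

    def get_state_size(self):
        return self.get_obs_size() * self.n_agents

    def get_avail_actions(self):
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        return [1] * self.get_total_actions()

    def get_total_actions(self):
        return 5                           # NOOP, FORWARD, LEFT, RIGHT, TOGGLE_LOAD

env = DummyEnv()
print(env.get_env_info())                  # uses the base-class implementation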
Lines changed: 70 additions & 0 deletions
import numpy as np
from dataclasses import dataclass

DIRECTION_UP = 0
DIRECTION_DOWN = 1
DIRECTION_LEFT = 2
DIRECTION_RIGHT = 3


@dataclass
class NearInformation:
    x: int
    y: int
    is_agent: bool
    agent_direction: int
    is_shelf: bool
    is_requested_shelf: bool


@dataclass
class Observation:
    x: int
    y: int
    is_carrying: bool
    direction: int
    is_path_location: bool
    near_info: list


class ObservationParser:

    @staticmethod
    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    @staticmethod
    def parse(obs):
        """Parses a flat RWARE observation vector into an Observation."""
        parsed_obs = Observation(
            x=obs[0],
            y=obs[1],
            is_carrying=obs[2] == 1.0,
            direction=int(np.argmax(obs[3:7])),      # 4-value direction one-hot
            is_path_location=obs[7] == 1.0,
            near_info=ObservationParser.parse_near_info(obs)
        )
        return parsed_obs

    @staticmethod
    def parse_near_info(obs):
        """Parses the 3x3 neighbourhood, encoded as nine 7-value chunks."""
        agent_x = obs[0]
        agent_y = obs[1]

        near_info = []
        infos = list(ObservationParser.chunks(obs[8:], 7))

        for i, info in enumerate(infos):
            row = i // 3
            col = i % 3
            near_info.append(NearInformation(
                x=agent_x - 1 + row,
                y=agent_y - 1 + col,
                is_agent=info[0] == 1.0,
                agent_direction=int(np.argmax(info[1:5])),
                is_shelf=info[5] == 1.0,
                is_requested_shelf=info[6] == 1.0
            ))

        return near_info
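
A quick usage sketch, assuming the flat layout the parser implies: an 8-value header (x, y, carrying flag, 4-value direction one-hot, path flag) followed by nine 7-value chunks for the 3x3 neighbourhood, with chunk 4 as the agent's own cell. The synthetic vector below is illustrative, not taken from the environment.

import numpy as np

# 8 header values plus nine 7-value chunks: 71 floats in total.
obs = np.zeros(8 + 9 * 7, dtype=np.float32)
obs[0], obs[1] = 4, 7            # agent position
obs[2] = 1.0                     # carrying a shelf
obs[3 + DIRECTION_RIGHT] = 1.0   # facing right
center = 8 + 4 * 7               # chunk 4 is the agent's own cell
obs[center + 5] = 1.0            # a shelf is here...
obs[center + 6] = 1.0            # ...and it is requested

parsed = ObservationParser.parse(obs)
assert parsed.is_carrying and parsed.direction == DIRECTION_RIGHT
assert parsed.near_info[4].is_requested_shelf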
Lines changed: 61 additions & 0 deletions
import math
from statistics import mean
from typing import Optional

from .observation_parser import NearInformation, Observation, ObservationParser


class RewardCalculator:

    @staticmethod
    def position_reward(env, x, y):
        """Small bonus for being close to the goals, normalised by the mean
        goal distance from the origin."""
        max_dist = []
        dist = []

        for goal in env.goals:
            goal_x = goal[0]
            goal_y = goal[1]
            dist.append(math.hypot(goal_x - x, goal_y - y))
            max_dist.append(math.hypot(goal_x, goal_y))  # distance from (0, 0)
        return 0.0005 * (mean(max_dist) - mean(dist)) / mean(max_dist)

    @staticmethod
    def is_center_shelf(obs: Observation, is_requested: bool) -> bool:
        """Checks whether the agent's own cell holds a shelf whose requested
        flag matches is_requested."""
        center_info: NearInformation = obs.near_info[4]
        return center_info.is_shelf and center_info.is_requested_shelf == is_requested

    @staticmethod
    def find_requested_shelf(obs: Observation) -> Optional[NearInformation]:
        """Returns the first requested shelf in the neighbourhood, if any."""
        near_info = obs.near_info
        info: NearInformation
        for info in near_info:
            if info.is_shelf and info.is_requested_shelf:
                return info
        return None

    @staticmethod
    def calculate(env, reward, prev_obs, obs):
        """Adds shaping terms to the environment reward.
        prev_obs is currently unused."""
        parsed: Observation = ObservationParser.parse(obs)

        # standing on a requested shelf
        if RewardCalculator.is_center_shelf(parsed, True):
            if parsed.is_carrying:
                reward += 0.006
            else:
                reward += 0.003

        # reward += RewardCalculator.position_reward(env, parsed.x, parsed.y)
        #
        # # standing on a non-requested shelf
        # if RewardCalculator.is_center_shelf(parsed, False):
        #     if parsed.is_carrying:
        #         reward -= 0.003
        #     else:
        #         reward -= 0.0015
        #
        # reward -= RewardCalculator.position_reward(env, parsed.x, parsed.y)

        # a requested shelf is visible somewhere in the neighbourhood
        if RewardCalculator.find_requested_shelf(parsed) is not None:
            reward += 0.001

        return reward
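
As a usage sketch, the calculator could be applied per agent after each environment step. This assumes a gym-style RWARE env whose step() returns lists of per-agent observations and rewards; shaped_step and its signature are hypothetical, not part of the commit.

def shaped_step(env, actions, prev_obs_list):
    # Step the underlying env, then shape each agent's reward from its
    # previous and current flat observation vectors.
    obs_list, rewards, dones, info = env.step(actions)
    shaped = [
        RewardCalculator.calculate(env, r, prev_obs, obs)
        for r, prev_obs, obs in zip(rewards, prev_obs_list, obs_list)
    ]
    return obs_list, shaped, dones, info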
Lines changed: 16 additions & 0 deletions
from enum import Enum


class Action(Enum):
    NOOP = 0
    FORWARD = 1
    LEFT = 2
    RIGHT = 3
    TOGGLE_LOAD = 4


class Direction(Enum):
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
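
These values line up with the DIRECTION_* constants in the observation parser (UP=0 through RIGHT=3) and with the five discrete RWARE actions. A small illustrative snippet:

# Illustrative only; Action and Direction come from the enum file above,
# DIRECTION_RIGHT from the observation parser.
joint_action = [Action.FORWARD, Action.TOGGLE_LOAD]
indices = [a.value for a in joint_action]    # -> [1, 4], as an env expects

facing = Direction(DIRECTION_RIGHT)          # -> Direction.RIGHT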
