
Commit 478ef1f

Merge pull request #30 from kfu02/snd-qmix-ns
add snd for qmix non ps and indepedent agent env logic
2 parents c5229bf + f5255fe commit 478ef1f

5 files changed (+95 lines, -71 lines)

baselines/QLearning/qmix.py

Lines changed: 28 additions & 28 deletions
@@ -57,26 +57,26 @@ class MixingNetwork(nn.Module):

    @nn.compact
    def __call__(self, q_vals, states):
-
+
        n_agents, time_steps, batch_size = q_vals.shape
        q_vals = jnp.transpose(q_vals, (1, 2, 0)) # (time_steps, batch_size, n_agents)
-
+
        # hypernetwork
        w_1 = HyperNetwork(hidden_dim=self.hypernet_hidden_dim, output_dim=self.embedding_dim*n_agents, init_scale=self.init_scale)(states)
        b_1 = nn.Dense(self.embedding_dim, kernel_init=orthogonal(self.init_scale), bias_init=constant(0.))(states)
        w_2 = HyperNetwork(hidden_dim=self.hypernet_hidden_dim, output_dim=self.embedding_dim, init_scale=self.init_scale)(states)
        b_2 = HyperNetwork(hidden_dim=self.embedding_dim, output_dim=1, init_scale=self.init_scale)(states)
-
+
        # monotonicity and reshaping
        w_1 = jnp.abs(w_1.reshape(time_steps, batch_size, n_agents, self.embedding_dim))
        b_1 = b_1.reshape(time_steps, batch_size, 1, self.embedding_dim)
        w_2 = jnp.abs(w_2.reshape(time_steps, batch_size, self.embedding_dim, 1))
        b_2 = b_2.reshape(time_steps, batch_size, 1, 1)
-
+
        # mix
        hidden = nn.elu(jnp.matmul(q_vals[:, :, None, :], w_1) + b_1)
        q_tot = jnp.matmul(hidden, w_2) + b_2
-
+
        return q_tot.squeeze() # (time_steps, batch_size)

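The mixing step above can be checked in isolation. Below is a minimal standalone sketch of the same monotonic matmul chain with made-up tensor sizes (T, B, N, E are hypothetical, not values from the repository); it only traces the shapes flowing through __call__:

import jax
import jax.numpy as jnp

# hypothetical sizes, used only to trace the shapes in MixingNetwork.__call__
T, B, N, E = 4, 8, 3, 32   # time_steps, batch_size, n_agents, embedding_dim
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)

q_vals = jax.random.normal(k1, (T, B, N))            # per-agent Qs after the transpose
w_1 = jnp.abs(jax.random.normal(k2, (T, B, N, E)))   # abs() keeps Q_tot monotone in each agent's Q
b_1 = jax.random.normal(k3, (T, B, 1, E))
w_2 = jnp.abs(jax.random.normal(k4, (T, B, E, 1)))
b_2 = jax.random.normal(k5, (T, B, 1, 1))

hidden = jax.nn.elu(jnp.matmul(q_vals[:, :, None, :], w_1) + b_1)   # (T, B, 1, E)
q_tot = jnp.matmul(hidden, w_2) + b_2                               # (T, B, 1, 1)
assert q_tot.squeeze().shape == (T, B)                              # matches the returned shape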

@@ -88,23 +88,23 @@ def __init__(self, start_e: float, end_e: float, duration: int):
        self.end_e = end_e
        self.duration = duration
        self.slope = (end_e - start_e) / duration
-
+
    @partial(jax.jit, static_argnums=0)
    def get_epsilon(self, t: int):
        e = self.slope*t + self.start_e
        return jnp.clip(e, self.end_e)
-
+
    @partial(jax.jit, static_argnums=0)
    def choose_actions(self, q_vals: dict, t: int, rng: chex.PRNGKey):
-
+
        def explore(q, eps, key):
            key_a, key_e = jax.random.split(key, 2) # a key for sampling random actions and one for picking
-           greedy_actions = jnp.argmax(q, axis=-1) # get the greedy actions
+           greedy_actions = jnp.argmax(q, axis=-1) # get the greedy actions
            random_actions = jax.random.randint(key_a, shape=greedy_actions.shape, minval=0, maxval=q.shape[-1]) # sample random actions
            pick_random = jax.random.uniform(key_e, greedy_actions.shape)<eps # pick which actions should be random
            chosed_actions = jnp.where(pick_random, random_actions, greedy_actions)
            return chosed_actions
-
+
        eps = self.get_epsilon(t)
        keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals)))) # get a key for each agent
        chosen_actions = jax.tree.map(lambda q, k: explore(q, eps, k), q_vals, keys)
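
The per-agent branch above is easy to exercise on its own. Here is a minimal sketch with hypothetical shapes (16 parallel envs, 5 valid actions) mirroring the explore logic; note that jnp.clip(e, self.end_e) in the scheduler simply floors the linearly decayed value at end_e:

import jax
import jax.numpy as jnp

def explore(q, eps, key):
    # argmax over the last axis unless a uniform draw falls below eps
    key_a, key_e = jax.random.split(key, 2)
    greedy_actions = jnp.argmax(q, axis=-1)
    random_actions = jax.random.randint(key_a, shape=greedy_actions.shape, minval=0, maxval=q.shape[-1])
    pick_random = jax.random.uniform(key_e, greedy_actions.shape) < eps
    return jnp.where(pick_random, random_actions, greedy_actions)

rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (16, 5))   # hypothetical: 16 envs, 5 valid actions
actions = explore(q, 0.1, rng)        # greedy ~90% of the time, random ~10%
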
@@ -128,7 +128,7 @@ def make_train(config, log_train_env, log_test_env, viz_test_env, env_name="MPE_
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )

-
+
    def train(rng):

        # INIT ENV
@@ -166,7 +166,7 @@ def _env_sample_step(env_state, unused):
            sample_sequence_length=1,
            period=1,
        )
-       buffer_state = buffer.init(sample_traj_unbatched)
+       buffer_state = buffer.init(sample_traj_unbatched)

        # INIT NETWORK
        # init agent
@@ -176,7 +176,7 @@ def _env_sample_step(env_state, unused):
            else:
                exit("HyperMLP deprecated currently!") # TODO: to fix, pass in AGENT_HYPERNET_KWARGS
                # agent = AgentHyperMLP(action_dim=wrapped_env.max_action_space, hidden_dim=config["AGENT_HIDDEN_DIM"], init_scale=config['AGENT_INIT_SCALE'], hypernet_hidden_dim=config["AGENT_HYPERNET_KWARGS"]["HIDDEN_DIM"], hypernet_init_scale=config["AGENT_HYPERNET_KWARGS"]["INIT_SCALE"], dim_capabilities=log_train_env.dim_capabilities)
-       else:
+       else:
            if not config["AGENT_HYPERAWARE"]:
                agent = AgentRNN(action_dim=wrapped_env.max_action_space, hidden_dim=config["AGENT_HIDDEN_DIM"], init_scale=config['AGENT_INIT_SCALE'])
            else:
@@ -290,7 +290,7 @@ def _env_step(step_state, unused):
            dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
            # get the q_values from the agent netwoek
            hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
-           # remove the dummy time_step dimension and index qs by the valid actions of each agent
+           # remove the dummy time_step dimension and index qs by the valid actions of each agent
            valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
            # explore with epsilon greedy_exploration
            actions = explorer.choose_actions(valid_q_vals, t, key_a)
@@ -315,7 +315,7 @@ def _env_step(step_state, unused):
                env_state,
                init_obs,
                init_dones,
-               hstate,
+               hstate,
                _rng,
                time_state['timesteps'] # t is needed to compute epsilon
            )
@@ -360,12 +360,12 @@ def _loss_fn(params, target_network_params, init_hstate, learn_traj):

            # compute q_tot with the mixer network
            chosen_action_qvals_mix = mixer.apply(
-               params['mixer'],
+               params['mixer'],
                jnp.stack(list(chosen_action_qvals.values())),
                learn_traj.obs['__all__'][:-1] # avoid last timestep
            )
            target_max_qvals_mix = mixer.apply(
-               target_network_params['mixer'],
+               target_network_params['mixer'],
                jnp.stack(list(target_max_qvals.values())),
                learn_traj.obs['__all__'][1:] # avoid first timestep
            )
@@ -399,7 +399,7 @@ def _td_lambda_target(ret, values):
                + config['GAMMA']*(1-learn_traj.dones['__all__'][:-1])*target_max_qvals_mix
            )
            loss = jnp.mean((chosen_action_qvals_mix - jax.lax.stop_gradient(targets))**2)
-
+
            return loss

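The loss in this hunk is a mean-squared TD error against a stop-gradient target. A minimal sketch of that computation with placeholder arrays (shapes and values are assumed, not the repository's trajectory format; the repository additionally folds the bootstrap into a TD(lambda) recursion):

import jax
import jax.numpy as jnp

T, B = 10, 32                               # assumed time and batch dimensions
gamma = 0.99
rewards = jnp.ones((T, B))
dones = jnp.zeros((T, B))
chosen_action_qvals_mix = jnp.ones((T, B))  # Q_tot of the actions actually taken
target_max_qvals_mix = jnp.ones((T, B))     # Q_tot of greedy actions under the target network

# one-step bootstrap target, masked where episodes terminate
targets = rewards + gamma * (1.0 - dones) * target_max_qvals_mix
loss = jnp.mean((chosen_action_qvals_mix - jax.lax.stop_gradient(targets)) ** 2)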

@@ -537,15 +537,15 @@ def _greedy_env_step(step_state, unused):
                env_state,
                init_obs,
                init_dones,
-               hstate,
+               hstate,
                _rng,
            )
            step_state, (rewards, dones, infos, viz_env_states, obs, hstate) = jax.lax.scan(
                _greedy_env_step, step_state, None, config["NUM_STEPS"]
            )

-           # get snd, NOTE: dim_c multiplier is currently hardcoded since it works for both fire and transport
-           snd_value = snd(rollouts=obs, hiddens=hstate, dim_c=len(test_env.training_agents)*2, params=params, alg='qmix', agent=agent)
+           # get snd, NOTE: dim_c multiplier is currently hardcoded since it works for both fire and transport
+           snd_value = snd(rollouts=obs, hiddens=hstate, dim_c=len(test_env.training_agents)*2, params=params, alg='qmix' if config["PARAMETERS_SHARING"] else 'qmix_ns', agent=agent)

            def fire_env_metrics(final_env_state):
                """
@@ -635,7 +635,7 @@ def callback(timestep, val):
                print(f"Timestep: {timestep}, return: {val}")
            jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], first_returns['__all__'].mean())
            return {"metrics": metrics, "viz_env_states": viz_env_states}
-
+
        time_state = {
            'timesteps':jnp.array(0),
            'updates': jnp.array(0)
@@ -662,7 +662,7 @@ def callback(timestep, val):
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {'runner_state':runner_state, 'metrics':metrics}
-
+
    return train

@hydra.main(version_base=None, config_path="./config", config_name="config")
@@ -673,7 +673,7 @@ def main(config):

    env_name = config["env"]["ENV_NAME"]
    alg_name = f'qmix_{"ps" if config["alg"].get("PARAMETERS_SHARING", True) else "ns"}'
-
+
    # smac init neeeds a scenario
    if 'smax' in env_name.lower():
        config['env']['ENV_KWARGS']['scenario'] = map_name_to_scenario(config['env']['MAP_NAME'])
@@ -688,7 +688,7 @@ def main(config):
    log_test_env = LogWrapper(viz_test_env)

    config["alg"]["NUM_STEPS"] = config["alg"].get("NUM_STEPS", train_env.max_steps) # default steps defined by the env
-
+
    hyper_tag = "hyper" if config["alg"]["AGENT_HYPERAWARE"] else "normal"
    recurrent_tag = "RNN" if config["alg"]["AGENT_RECURRENT"] else "MLP"
    aware_tag = "aware" if config["env"]["ENV_KWARGS"]["capability_aware"] else "unaware"
@@ -714,12 +714,12 @@ def main(config):
        config=config,
        mode=config["WANDB_MODE"],
    )
-
+
    rng = jax.random.PRNGKey(config["SEED"])
    rngs = jax.random.split(rng, config["NUM_SEEDS"])
    train_vjit = jax.jit(jax.vmap(make_train(config["alg"], log_train_env, log_test_env, viz_test_env, env_name=config["env"]["ENV_NAME"])))
    outs = jax.block_until_ready(train_vjit(rngs))
-
+
    # save params
    if config['SAVE_PATH'] is not None:

@@ -779,4 +779,4 @@ def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:

if __name__ == "__main__":
    main()
-
+
jaxmarl/environments/mpe/simple.py

Lines changed: 5 additions & 4 deletions
@@ -1,4 +1,4 @@
-"""
+"""
Base class for MPE PettingZoo envs.

TODO: viz for communication env, e.g. crypto
@@ -52,6 +52,7 @@ def __init__(
    ):
        self.test_env_flag = kwargs["test_env_flag"] if "test_env_flag" in kwargs else False
        self.test_capabilities = kwargs["test_capabilities"] if "test_capabilities" in kwargs else None
+       self.independent_agents = kwargs["independent_agents"] if "independent_agents" in kwargs else None

        # Agent and entity constants
        self.num_agents = num_agents
@@ -136,7 +137,7 @@ def __init__(
            self.agent_accels = kwargs["agent_accels"]
            # assert (len(self.agent_accels) >= self.num_agents), f"Not enough agent_accels, {len(self.agent_accels)} < {self.num_agents}"
            self.agent_accels = jnp.array(self.agent_accels)
-
+
        if "agent_capacities" in kwargs:
            self.agent_capacities = kwargs["agent_capacities"]
            self.agent_capacities = jnp.array(self.agent_capacities)
@@ -297,7 +298,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        # if self.test_env_flag and self.test_capabilities is not None:
        #     team_capabilities = jnp.asarray(self.test_capabilities)

-
+
        agent_rads = self.agent_rads[selected_agents]
        agent_accels = self.agent_accels[selected_agents]
        agent_capacities = self.agent_capacities[selected_agents] if self.agent_capacities else np.zeros((self.num_agents, 2))
@@ -521,7 +522,7 @@ def map_bounds_reward(self, x: float):
        m = x < 1.0
        mr = (x - 0.9) * 10
        br = jnp.min(jnp.array([jnp.exp(2 * x - 2), 10]))
-       return jax.lax.select(m, mr, br) * ~w
+       return jax.lax.select(m, mr, br) * ~w


if __name__ == "__main__":

jaxmarl/environments/mpe/simple_fire.py

Lines changed: 20 additions & 15 deletions
@@ -30,7 +30,7 @@ def __init__(
        fire_pos_dim = num_landmarks * 2
        fire_rad_dim = num_landmarks
        observation_spaces = {
-           i:Box(-jnp.inf, jnp.inf, (pos_dim + vel_dim + self.dim_capabilities + fire_pos_dim + fire_rad_dim))
+           i:Box(-jnp.inf, jnp.inf, (pos_dim + vel_dim + self.dim_capabilities + fire_pos_dim + fire_rad_dim))
            for i in agents
        }

@@ -40,7 +40,7 @@ def __init__(
        # env specific parameters
        self.test_teams = jnp.array(kwargs["test_teams"]) if "test_teams" in kwargs else None
        self.fire_rad_range = kwargs["fire_rad_range"] if "fire_rad_range" in kwargs else [0.2, 0.3]
-
+
        # reward shaping
        self.fire_out_reward = kwargs["fire_out_reward"] if "fire_out_reward" in kwargs else 1
        self.uncovered_penalty_factor = kwargs["uncovered_penalty_factor"] if "uncovered_penalty_factor" in kwargs else 2
@@ -207,7 +207,7 @@ def _spawn_one_fire(carry, _):

            # if new fire spawn is valid, add it to the fire list, and incr the fire index
            new_fires = jax.lax.cond(
-               new_fire_valid,
+               new_fire_valid,
                lambda: new_fire_added, # T
                lambda: existing_fires, # F
            )
@@ -240,18 +240,23 @@ def _spawn_one_fire(carry, _):
            ]
        )

-       # randomly sample N_agents' capabilities from the possible agent pool (hence w/out replacement)
-       selected_agents = jax.random.choice(key_c, self.num_agents, shape=(self.num_agents,), replace=False)
-       agent_rads = self.agent_rads[selected_agents]
-       agent_accels = self.agent_accels[selected_agents]
-
-       # unless a test distribution is provided and this is a test_env
-       if self.test_env_flag and self.test_teams is not None:
-           # pick one of the test teams at random
-           selected_team = jax.random.choice(key_tt, self.test_teams.shape[0], shape=(1,))
-           test_team = self.test_teams[selected_team].squeeze()
-           agent_rads = test_team[0::2]
-           agent_accels = test_team[1::2]
+       if self.independent_agents:
+           # if independent policies do not sample teams and capabilities, keep constant
+           # NOTE: assumes that agent_rad and agent_accels are n_agent length
+           agent_rads = self.agent_rads
+           agent_accels = self.agent_accels
+       else:
+           # randomly sample N_agents' capabilities from the possible agent pool (hence w/out replacement)
+           selected_agents = jax.random.choice(key_c, self.num_agents, shape=(self.num_agents,), replace=False)
+           agent_rads = self.agent_rads[selected_agents]
+           agent_accels = self.agent_accels[selected_agents]
+           # unless a test distribution is provided and this is a test_env
+           if self.test_env_flag and self.test_teams is not None:
+               # pick one of the test teams at random
+               selected_team = jax.random.choice(key_tt, self.test_teams.shape[0], shape=(1,))
+               test_team = self.test_teams[selected_team].squeeze()
+               agent_rads = test_team[0::2]
+               agent_accels = test_team[1::2]

        state = State(
            p_pos=p_pos,
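
The new branch keeps each agent's capabilities fixed when independent_agents is set and otherwise falls back to the original without-replacement sampling (plus the optional test-team override). A minimal standalone sketch of that control flow with placeholder arrays (hypothetical values, not the env's actual reset):

import jax
import jax.numpy as jnp

def select_capabilities(key, agent_rads, agent_accels, independent_agents):
    if independent_agents:
        # independent (non-parameter-shared) policies: capabilities stay constant across resets
        return agent_rads, agent_accels
    # shared-policy training: permute the capability pool without replacement each reset
    n = agent_rads.shape[0]
    idx = jax.random.choice(key, n, shape=(n,), replace=False)
    return agent_rads[idx], agent_accels[idx]

key = jax.random.PRNGKey(0)
rads = jnp.array([0.1, 0.2, 0.3])     # hypothetical per-agent radii
accels = jnp.array([1.0, 2.0, 3.0])   # hypothetical per-agent accelerations
print(select_capabilities(key, rads, accels, independent_agents=False))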
