@@ -57,26 +57,26 @@ class MixingNetwork(nn.Module):
 
     @nn.compact
     def __call__(self, q_vals, states):
-        
+
         n_agents, time_steps, batch_size = q_vals.shape
         q_vals = jnp.transpose(q_vals, (1, 2, 0))  # (time_steps, batch_size, n_agents)
-        
+
         # hypernetwork
         w_1 = HyperNetwork(hidden_dim=self.hypernet_hidden_dim, output_dim=self.embedding_dim * n_agents, init_scale=self.init_scale)(states)
         b_1 = nn.Dense(self.embedding_dim, kernel_init=orthogonal(self.init_scale), bias_init=constant(0.))(states)
         w_2 = HyperNetwork(hidden_dim=self.hypernet_hidden_dim, output_dim=self.embedding_dim, init_scale=self.init_scale)(states)
         b_2 = HyperNetwork(hidden_dim=self.embedding_dim, output_dim=1, init_scale=self.init_scale)(states)
-        
+
         # monotonicity and reshaping
         w_1 = jnp.abs(w_1.reshape(time_steps, batch_size, n_agents, self.embedding_dim))
         b_1 = b_1.reshape(time_steps, batch_size, 1, self.embedding_dim)
         w_2 = jnp.abs(w_2.reshape(time_steps, batch_size, self.embedding_dim, 1))
         b_2 = b_2.reshape(time_steps, batch_size, 1, 1)
-        
+
         # mix
         hidden = nn.elu(jnp.matmul(q_vals[:, :, None, :], w_1) + b_1)
         q_tot = jnp.matmul(hidden, w_2) + b_2
-        
+
         return q_tot.squeeze()  # (time_steps, batch_size)
 
 
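For reference, a minimal, self-contained sketch of the monotonic mixing computed in the hunk above, written in plain JAX with random stand-ins for the hypernetwork outputs (shapes follow the hunk; this is illustrative, not the repo's Flax module):

# Hedged sketch: QMIX-style monotonic mixing with |W| weights from a hypernetwork.
import jax
import jax.numpy as jnp

n_agents, time_steps, batch_size, embed = 3, 5, 4, 32
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)

q_vals = jax.random.normal(k1, (time_steps, batch_size, n_agents))               # per-agent Q-values
w_1 = jnp.abs(jax.random.normal(k2, (time_steps, batch_size, n_agents, embed)))  # |W1| keeps dQ_tot/dQ_i >= 0
b_1 = jax.random.normal(k3, (time_steps, batch_size, 1, embed))
w_2 = jnp.abs(jax.random.normal(k4, (time_steps, batch_size, embed, 1)))         # |W2| for the output layer
b_2 = jax.random.normal(k5, (time_steps, batch_size, 1, 1))

hidden = jax.nn.elu(jnp.matmul(q_vals[:, :, None, :], w_1) + b_1)  # (time_steps, batch_size, 1, embed)
q_tot = (jnp.matmul(hidden, w_2) + b_2).squeeze()                  # (time_steps, batch_size)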
@@ -88,23 +88,23 @@ def __init__(self, start_e: float, end_e: float, duration: int):
         self.end_e = end_e
         self.duration = duration
         self.slope = (end_e - start_e) / duration
-        
+
     @partial(jax.jit, static_argnums=0)
     def get_epsilon(self, t: int):
         e = self.slope * t + self.start_e
         return jnp.clip(e, self.end_e)
-        
+
     @partial(jax.jit, static_argnums=0)
     def choose_actions(self, q_vals: dict, t: int, rng: chex.PRNGKey):
-        
+
         def explore(q, eps, key):
             key_a, key_e = jax.random.split(key, 2)  # a key for sampling random actions and one for picking
-            greedy_actions = jnp.argmax(q, axis=-1)  # get the greedy actions 
+            greedy_actions = jnp.argmax(q, axis=-1)  # get the greedy actions
             random_actions = jax.random.randint(key_a, shape=greedy_actions.shape, minval=0, maxval=q.shape[-1])  # sample random actions
             pick_random = jax.random.uniform(key_e, greedy_actions.shape) < eps  # pick which actions should be random
             chosed_actions = jnp.where(pick_random, random_actions, greedy_actions)
             return chosed_actions
-        
+
         eps = self.get_epsilon(t)
         keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals))))  # get a key for each agent
         chosen_actions = jax.tree.map(lambda q, k: explore(q, eps, k), q_vals, keys)
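A standalone sketch of the same schedule and action rule (illustrative names; for a decaying schedule with end_e < start_e, the jnp.clip above acts as a floor at end_e):

# Hedged sketch: linear epsilon decay plus epsilon-greedy selection on a batch of Q-vectors.
import jax
import jax.numpy as jnp

def linear_epsilon(t, start_e=1.0, end_e=0.05, duration=10_000):
    # decay linearly from start_e towards end_e, then hold at end_e
    return jnp.clip(start_e + (end_e - start_e) / duration * t, end_e)

def eps_greedy(q, eps, key):
    key_a, key_e = jax.random.split(key)
    greedy = jnp.argmax(q, axis=-1)
    random_a = jax.random.randint(key_a, greedy.shape, 0, q.shape[-1])
    return jnp.where(jax.random.uniform(key_e, greedy.shape) < eps, random_a, greedy)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (8, 5))                 # e.g. 8 parallel envs, 5 actions
actions = eps_greedy(q, linear_epsilon(100), key)  # (8,) chosen action indices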
@@ -128,7 +128,7 @@ def make_train(config, log_train_env, log_test_env, viz_test_env, env_name="MPE_
         config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
     )
 
-    
+
     def train(rng):
 
         # INIT ENV
@@ -166,7 +166,7 @@ def _env_sample_step(env_state, unused):
             sample_sequence_length=1,
             period=1,
         )
-        buffer_state = buffer.init(sample_traj_unbatched) 
+        buffer_state = buffer.init(sample_traj_unbatched)
 
         # INIT NETWORK
         # init agent
@@ -176,7 +176,7 @@ def _env_sample_step(env_state, unused):
             else:
                 exit("HyperMLP deprecated currently!")  # TODO: to fix, pass in AGENT_HYPERNET_KWARGS
                 # agent = AgentHyperMLP(action_dim=wrapped_env.max_action_space, hidden_dim=config["AGENT_HIDDEN_DIM"], init_scale=config['AGENT_INIT_SCALE'], hypernet_hidden_dim=config["AGENT_HYPERNET_KWARGS"]["HIDDEN_DIM"], hypernet_init_scale=config["AGENT_HYPERNET_KWARGS"]["INIT_SCALE"], dim_capabilities=log_train_env.dim_capabilities)
-        else: 
+        else:
             if not config["AGENT_HYPERAWARE"]:
                 agent = AgentRNN(action_dim=wrapped_env.max_action_space, hidden_dim=config["AGENT_HIDDEN_DIM"], init_scale=config['AGENT_INIT_SCALE'])
             else:
@@ -290,7 +290,7 @@ def _env_step(step_state, unused):
             dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
             # get the q_values from the agent network
             hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
-            # remove the dummy time_step dimension and index qs by the valid actions of each agent 
+            # remove the dummy time_step dimension and index qs by the valid actions of each agent
             valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
             # explore with epsilon greedy_exploration
             actions = explorer.choose_actions(valid_q_vals, t, key_a)
@@ -315,7 +315,7 @@ def _env_step(step_state, unused):
             env_state,
             init_obs,
             init_dones,
-            hstate, 
+            hstate,
             _rng,
             time_state['timesteps']  # t is needed to compute epsilon
         )
@@ -360,12 +360,12 @@ def _loss_fn(params, target_network_params, init_hstate, learn_traj):
 
             # compute q_tot with the mixer network
             chosen_action_qvals_mix = mixer.apply(
-                params['mixer'], 
+                params['mixer'],
                 jnp.stack(list(chosen_action_qvals.values())),
                 learn_traj.obs['__all__'][:-1]  # avoid last timestep
             )
             target_max_qvals_mix = mixer.apply(
-                target_network_params['mixer'], 
+                target_network_params['mixer'],
                 jnp.stack(list(target_max_qvals.values())),
                 learn_traj.obs['__all__'][1:]  # avoid first timestep
             )
@@ -399,7 +399,7 @@ def _td_lambda_target(ret, values):
                 + config['GAMMA']*(1 - learn_traj.dones['__all__'][:-1])*target_max_qvals_mix
             )
             loss = jnp.mean((chosen_action_qvals_mix - jax.lax.stop_gradient(targets))**2)
-            
+
             return loss
 
 
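The loss in this hunk is a standard mixed-Q TD regression; a hedged one-step sketch of the target shape (illustrative names, not the repo's full TD(lambda) path):

# Hedged sketch of the regression target and MSE loss.
import jax
import jax.numpy as jnp

def qmix_loss(chosen_q_tot, target_q_tot, rewards, dones, gamma=0.99):
    # chosen_q_tot aligns with timestep t, target_q_tot with t+1; dones gate the bootstrap term
    targets = rewards + gamma * (1.0 - dones) * target_q_tot
    return jnp.mean((chosen_q_tot - jax.lax.stop_gradient(targets)) ** 2)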
@@ -537,15 +537,15 @@ def _greedy_env_step(step_state, unused):
             env_state,
             init_obs,
             init_dones,
-            hstate, 
+            hstate,
             _rng,
         )
         step_state, (rewards, dones, infos, viz_env_states, obs, hstate) = jax.lax.scan(
             _greedy_env_step, step_state, None, config["NUM_STEPS"]
         )
 
-        # get snd, NOTE: dim_c multiplier is currently hardcoded since it works for both fire and transport 
-        snd_value = snd(rollouts=obs, hiddens=hstate, dim_c=len(test_env.training_agents)*2, params=params, alg='qmix', agent=agent)
+        # get snd, NOTE: dim_c multiplier is currently hardcoded since it works for both fire and transport
+        snd_value = snd(rollouts=obs, hiddens=hstate, dim_c=len(test_env.training_agents)*2, params=params, alg='qmix' if config["PARAMETERS_SHARING"] else 'qmix_ns', agent=agent)
 
         def fire_env_metrics(final_env_state):
             """
@@ -635,7 +635,7 @@ def callback(timestep, val):
                 print(f"Timestep: {timestep}, return: {val}")
             jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], first_returns['__all__'].mean())
             return {"metrics": metrics, "viz_env_states": viz_env_states}
-        
+
         time_state = {
             'timesteps': jnp.array(0),
             'updates': jnp.array(0)
@@ -662,7 +662,7 @@ def callback(timestep, val):
             _update_step, runner_state, None, config["NUM_UPDATES"]
         )
         return {'runner_state': runner_state, 'metrics': metrics}
-    
+
     return train
 
 @hydra.main(version_base=None, config_path="./config", config_name="config")
@@ -673,7 +673,7 @@ def main(config):
 
     env_name = config["env"]["ENV_NAME"]
     alg_name = f'qmix_{"ps" if config["alg"].get("PARAMETERS_SHARING", True) else "ns"}'
-    
+
     # smac init needs a scenario
     if 'smax' in env_name.lower():
         config['env']['ENV_KWARGS']['scenario'] = map_name_to_scenario(config['env']['MAP_NAME'])
@@ -688,7 +688,7 @@ def main(config):
     log_test_env = LogWrapper(viz_test_env)
 
     config["alg"]["NUM_STEPS"] = config["alg"].get("NUM_STEPS", train_env.max_steps)  # default steps defined by the env
-    
+
     hyper_tag = "hyper" if config["alg"]["AGENT_HYPERAWARE"] else "normal"
     recurrent_tag = "RNN" if config["alg"]["AGENT_RECURRENT"] else "MLP"
     aware_tag = "aware" if config["env"]["ENV_KWARGS"]["capability_aware"] else "unaware"
@@ -714,12 +714,12 @@ def main(config):
         config=config,
         mode=config["WANDB_MODE"],
     )
-    
+
     rng = jax.random.PRNGKey(config["SEED"])
     rngs = jax.random.split(rng, config["NUM_SEEDS"])
     train_vjit = jax.jit(jax.vmap(make_train(config["alg"], log_train_env, log_test_env, viz_test_env, env_name=config["env"]["ENV_NAME"])))
     outs = jax.block_until_ready(train_vjit(rngs))
-    
+
     # save params
     if config['SAVE_PATH'] is not None:
 
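The seed-vectorisation pattern in this hunk, reduced to a hedged toy (make_train here is a stub, not the repo's signature):

# Hedged sketch: vmap a pure train(rng) function over one PRNG key per seed, then jit once.
import jax

def make_train(config):
    def train(rng):
        # ... the full training loop would go here, as a pure function of the PRNG key ...
        return {"seed_key": rng}
    return train

rngs = jax.random.split(jax.random.PRNGKey(0), 4)  # e.g. NUM_SEEDS = 4
train_vjit = jax.jit(jax.vmap(make_train({})))
outs = jax.block_until_ready(train_vjit(rngs))     # every output leaf gains a leading seed axis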
@@ -779,4 +779,4 @@ def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
 
 if __name__ == "__main__":
     main()
-    
+