|
46 | 46 | import gymnasium as gym |
47 | 47 | from pettingzoo import AECEnv, ParallelEnv |
48 | 48 |
|
| 49 | +# TODO: implement these infos in savanna_safetygrid.py instead |
| 50 | +INFO_PIPELINE_CYCLE = "pipeline_cycle" |
| 51 | +INFO_EPISODE = "episode" |
| 52 | +INFO_ENV_LAYOUT_SEED = "env_layout_seed" |
| 53 | +INFO_STEP = "step" |
| 54 | +INFO_TEST_MODE = "test_mode" |
| 55 | + |
49 | 56 | PettingZooEnv = Union[AECEnv, ParallelEnv] |
50 | 57 | Environment = Union[gym.Env, PettingZooEnv] |
51 | 58 |
|
@@ -199,7 +206,6 @@ def sb3_agent_train_thread_entry_point( |
199 | 206 |
|
200 | 207 | model = model_constructor(env_wrapper, env_classname, agent_id, cfg) |
201 | 208 | env_wrapper.set_model(model) |
202 | | - self.model = model |
203 | 209 | model.learn(total_timesteps=num_total_steps, callback=checkpoint_callback) |
204 | 210 | env_wrapper.save_or_return_model(model, filename_timestamp_sufix_str) |
205 | 211 | except ( |
@@ -299,10 +305,11 @@ def get_action( |
299 | 305 | # action_space = self.env.action_space(self.id) |
300 | 306 | self.info = info |
301 | 307 |
|
302 | | - self.info["i_pipeline_cycle"] = pipeline_cycle |
303 | | - self.info["i_episode"] = episode |
304 | | - self.info["step"] = step |
305 | | - self.info["test_mode"] = test_mode |
| 308 | + self.info[INFO_PIPELINE_CYCLE] = pipeline_cycle |
| 309 | + self.info[INFO_EPISODE] = episode |
| 310 | + self.info[INFO_ENV_LAYOUT_SEED] = env_layout_seed |
| 311 | + self.info[INFO_STEP] = step |
| 312 | + self.info[INFO_TEST_MODE] = test_mode |
306 | 313 |
|
307 | 314 | self.infos[self.id] = self.info |
308 | 315 |
|
@@ -365,17 +372,21 @@ def env_post_reset_callback(self, states, infos, seed, options, *args, **kwargs) |
365 | 372 | i_episode = ( |
366 | 373 | self.next_episode_no - 1 |
367 | 374 | ) # cannot use env.get_next_episode_no() here since its counter is reset for each new env_layout_seed |
| 375 | + env_layout_seed = ( |
| 376 | + self.env.get_env_layout_seed() |
| 377 | + ) # no need to substract 1 here since env_layout_seed value is overridden in env_pre_reset_callback |
368 | 378 | step = 0 |
369 | 379 | test_mode = False |
370 | 380 |
|
371 | 381 | for ( |
372 | 382 | agent, |
373 | 383 | info, |
374 | 384 | ) in infos.items(): # TODO: move this code to savanna_safetygrid.py |
375 | | - info["i_pipeline_cycle"] = i_pipeline_cycle |
376 | | - info["i_episode"] = i_episode |
377 | | - info["step"] = 0 |
378 | | - info["test_mode"] = test_mode |
| 385 | + info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle |
| 386 | + info[INFO_EPISODE] = i_episode |
| 387 | + info[INFO_ENV_LAYOUT_SEED] = env_layout_seed |
| 388 | + info[INFO_STEP] = 0 |
| 389 | + info[INFO_TEST_MODE] = test_mode |
379 | 390 |
|
380 | 391 | if self.model: |
381 | 392 | if hasattr(self.model.policy, "my_reset"): |
@@ -436,10 +447,11 @@ def parallel_env_post_step_callback( |
436 | 447 | done = terminateds[agent] or truncateds[agent] |
437 | 448 |
|
438 | 449 | # TODO: move this code to savanna_safetygrid.py |
439 | | - info["i_pipeline_cycle"] = i_pipeline_cycle |
440 | | - info["i_episode"] = i_episode |
441 | | - info["step"] = step |
442 | | - info["test_mode"] = test_mode |
| 450 | + info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle |
| 451 | + info[INFO_EPISODE] = i_episode |
| 452 | + info[INFO_ENV_LAYOUT_SEED] = env_layout_seed |
| 453 | + info[INFO_STEP] = step |
| 454 | + info[INFO_TEST_MODE] = test_mode |
443 | 455 |
|
444 | 456 | agent_step_info = [ |
445 | 457 | agent, |
@@ -541,10 +553,11 @@ def sequential_env_post_step_callback( |
541 | 553 | test_mode = False |
542 | 554 |
|
543 | 555 | # TODO: move this code to savanna_safetygrid.py |
544 | | - self.info["i_pipeline_cycle"] = i_pipeline_cycle |
545 | | - self.info["i_episode"] = i_episode |
546 | | - self.info["step"] = step |
547 | | - self.info["test_mode"] = test_mode |
| 556 | + self.info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle |
| 557 | + self.info[INFO_EPISODE] = i_episode |
| 558 | + self.info[INFO_ENV_LAYOUT_SEED] = env_layout_seed |
| 559 | + self.info[INFO_STEP] = step |
| 560 | + self.info[INFO_TEST_MODE] = test_mode |
548 | 561 |
|
549 | 562 | self.infos[self.id] = self.info |
550 | 563 |
|
|
0 commit comments