
Commit 19126b7

Adding missing env_layout_seed info to SB3 instincts
1 parent 5e62002 commit 19126b7

File tree: 5 files changed (+68, -29 lines)

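In short, the commit switches the agents' info-dict bookkeeping from string literals to shared INFO_* key constants and threads an env_layout_seed value through alongside step, episode, pipeline_cycle and test_mode. The sketch below only illustrates that pattern: the constant names and values are taken verbatim from the diff, while read_step_context() is a hypothetical helper, not part of the repository.

# Constants as defined in sb3_base_agent.py in this commit.
INFO_PIPELINE_CYCLE = "pipeline_cycle"
INFO_EPISODE = "episode"
INFO_ENV_LAYOUT_SEED = "env_layout_seed"
INFO_STEP = "step"
INFO_TEST_MODE = "test_mode"


def read_step_context(info: dict) -> tuple:
    """Hypothetical helper: unpack the per-step bookkeeping from an info dict."""
    return (
        info[INFO_STEP],
        info[INFO_ENV_LAYOUT_SEED],
        info[INFO_EPISODE],
        info[INFO_PIPELINE_CYCLE],
        info[INFO_TEST_MODE],
    )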

aintelope/agents/a2c_agent.py

Lines changed: 12 additions & 4 deletions
@@ -22,6 +22,11 @@
     SB3BaseAgent,
     CustomCNN,
     PolicyWithConfigFactory,
+    INFO_PIPELINE_CYCLE,
+    INFO_EPISODE,
+    INFO_ENV_LAYOUT_SEED,
+    INFO_STEP,
+    INFO_TEST_MODE,
 )
 from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
 from aintelope.training.dqn_training import Trainer
@@ -95,17 +100,19 @@ def forward(
         distribution = self._get_action_dist_from_latent(latent_pi)
 
         # inserted code
-        step = self.info["step"]
-        episode = self.info["i_episode"]
-        pipeline_cycle = self.info["i_pipeline_cycle"]
-        test_mode = self.info["test_mode"]
+        step = self.info[INFO_STEP]
+        env_layout_seed = self.info[INFO_ENV_LAYOUT_SEED]
+        episode = self.info[INFO_EPISODE]
+        pipeline_cycle = self.info[INFO_PIPELINE_CYCLE]
+        test_mode = self.info[INFO_TEST_MODE]
 
         obs_nps = obs.detach().cpu().numpy()
         obs_np = obs_nps[0, :]
 
         (override_type, _random) = self.expert.should_override(
             deterministic,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,
@@ -116,6 +123,7 @@ def forward(
             obs_np,
             self.info,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,

aintelope/agents/dqn_agent.py

Lines changed: 12 additions & 4 deletions
@@ -22,6 +22,11 @@
     SB3BaseAgent,
     CustomCNN,
     PolicyWithConfigFactory,
+    INFO_PIPELINE_CYCLE,
+    INFO_EPISODE,
+    INFO_ENV_LAYOUT_SEED,
+    INFO_STEP,
+    INFO_TEST_MODE,
 )
 from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
 from aintelope.training.dqn_training import Trainer
@@ -74,17 +79,19 @@ def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         actions = self.q_net._predict(obs, deterministic=deterministic)
 
         # inserted code
-        step = self.info["step"]
-        episode = self.info["i_episode"]
-        pipeline_cycle = self.info["i_pipeline_cycle"]
-        test_mode = self.info["test_mode"]
+        step = self.info[INFO_STEP]
+        env_layout_seed = self.info[INFO_ENV_LAYOUT_SEED]
+        episode = self.info[INFO_EPISODE]
+        pipeline_cycle = self.info[INFO_PIPELINE_CYCLE]
+        test_mode = self.info[INFO_TEST_MODE]
 
         obs_nps = obs.detach().cpu().numpy()
         obs_np = obs_nps[0, :]
 
         (override_type, _random) = self.expert.should_override(
             deterministic,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,
@@ -95,6 +102,7 @@ def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
             obs_np,
             self.info,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,

aintelope/agents/ppo_agent.py

Lines changed: 12 additions & 4 deletions
@@ -23,6 +23,11 @@
     CustomCNN,
     vec_env_args,
     PolicyWithConfigFactory,
+    INFO_PIPELINE_CYCLE,
+    INFO_EPISODE,
+    INFO_ENV_LAYOUT_SEED,
+    INFO_STEP,
+    INFO_TEST_MODE,
 )
 from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
 from aintelope.training.dqn_training import Trainer
@@ -95,17 +100,19 @@ def forward(
         distribution = self._get_action_dist_from_latent(latent_pi)
 
         # inserted code
-        step = self.info["step"]
-        episode = self.info["i_episode"]
-        pipeline_cycle = self.info["i_pipeline_cycle"]
-        test_mode = self.info["test_mode"]
+        step = self.info[INFO_STEP]
+        env_layout_seed = self.info[INFO_ENV_LAYOUT_SEED]
+        episode = self.info[INFO_EPISODE]
+        pipeline_cycle = self.info[INFO_PIPELINE_CYCLE]
+        test_mode = self.info[INFO_TEST_MODE]
 
         obs_nps = obs.detach().cpu().numpy()
         obs_np = obs_nps[0, :]
 
         (override_type, _random) = self.expert.should_override(
             deterministic,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,
@@ -116,6 +123,7 @@ def forward(
             obs_np,
             self.info,
             step,
+            env_layout_seed,
             episode,
             pipeline_cycle,
             test_mode,

aintelope/agents/sb3_base_agent.py

Lines changed: 30 additions & 17 deletions
@@ -46,6 +46,13 @@
 import gymnasium as gym
 from pettingzoo import AECEnv, ParallelEnv
 
+# TODO: implement these infos in savanna_safetygrid.py instead
+INFO_PIPELINE_CYCLE = "pipeline_cycle"
+INFO_EPISODE = "episode"
+INFO_ENV_LAYOUT_SEED = "env_layout_seed"
+INFO_STEP = "step"
+INFO_TEST_MODE = "test_mode"
+
 PettingZooEnv = Union[AECEnv, ParallelEnv]
 Environment = Union[gym.Env, PettingZooEnv]
 
@@ -199,7 +206,6 @@ def sb3_agent_train_thread_entry_point(
 
         model = model_constructor(env_wrapper, env_classname, agent_id, cfg)
         env_wrapper.set_model(model)
-        self.model = model
         model.learn(total_timesteps=num_total_steps, callback=checkpoint_callback)
         env_wrapper.save_or_return_model(model, filename_timestamp_sufix_str)
     except (
@@ -299,10 +305,11 @@ def get_action(
         # action_space = self.env.action_space(self.id)
         self.info = info
 
-        self.info["i_pipeline_cycle"] = pipeline_cycle
-        self.info["i_episode"] = episode
-        self.info["step"] = step
-        self.info["test_mode"] = test_mode
+        self.info[INFO_PIPELINE_CYCLE] = pipeline_cycle
+        self.info[INFO_EPISODE] = episode
+        self.info[INFO_ENV_LAYOUT_SEED] = env_layout_seed
+        self.info[INFO_STEP] = step
+        self.info[INFO_TEST_MODE] = test_mode
 
         self.infos[self.id] = self.info
 
@@ -365,17 +372,21 @@ def env_post_reset_callback(self, states, infos, seed, options, *args, **kwargs)
         i_episode = (
             self.next_episode_no - 1
         )  # cannot use env.get_next_episode_no() here since its counter is reset for each new env_layout_seed
+        env_layout_seed = (
+            self.env.get_env_layout_seed()
+        )  # no need to substract 1 here since env_layout_seed value is overridden in env_pre_reset_callback
         step = 0
         test_mode = False
 
         for (
             agent,
             info,
         ) in infos.items():  # TODO: move this code to savanna_safetygrid.py
-            info["i_pipeline_cycle"] = i_pipeline_cycle
-            info["i_episode"] = i_episode
-            info["step"] = 0
-            info["test_mode"] = test_mode
+            info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle
+            info[INFO_EPISODE] = i_episode
+            info[INFO_ENV_LAYOUT_SEED] = env_layout_seed
+            info[INFO_STEP] = 0
+            info[INFO_TEST_MODE] = test_mode
 
         if self.model:
             if hasattr(self.model.policy, "my_reset"):
@@ -436,10 +447,11 @@ def parallel_env_post_step_callback(
             done = terminateds[agent] or truncateds[agent]
 
             # TODO: move this code to savanna_safetygrid.py
-            info["i_pipeline_cycle"] = i_pipeline_cycle
-            info["i_episode"] = i_episode
-            info["step"] = step
-            info["test_mode"] = test_mode
+            info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle
+            info[INFO_EPISODE] = i_episode
+            info[INFO_ENV_LAYOUT_SEED] = env_layout_seed
+            info[INFO_STEP] = step
+            info[INFO_TEST_MODE] = test_mode
 
             agent_step_info = [
                 agent,
@@ -541,10 +553,11 @@ def sequential_env_post_step_callback(
         test_mode = False
 
         # TODO: move this code to savanna_safetygrid.py
-        self.info["i_pipeline_cycle"] = i_pipeline_cycle
-        self.info["i_episode"] = i_episode
-        self.info["step"] = step
-        self.info["test_mode"] = test_mode
+        self.info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle
+        self.info[INFO_EPISODE] = i_episode
+        self.info[INFO_ENV_LAYOUT_SEED] = env_layout_seed
+        self.info[INFO_STEP] = step
+        self.info[INFO_TEST_MODE] = test_mode
 
         self.infos[self.id] = self.info
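
All of the base-agent callbacks above follow the same stamping pattern: after each reset and step, the wrapper-level counters, now including the env layout seed, are copied into every agent's info dict under the shared keys. A simplified sketch of the post-reset case, assuming an env wrapper exposing get_env_layout_seed() as in the diff (stamp_reset_infos() itself is a hypothetical stand-in, not a function in the repository):

def stamp_reset_infos(env, infos: dict, i_pipeline_cycle: int, i_episode: int) -> None:
    """Hypothetical stand-in for the post-reset stamping shown above."""
    env_layout_seed = env.get_env_layout_seed()  # value set by env_pre_reset_callback
    for agent, info in infos.items():
        info[INFO_PIPELINE_CYCLE] = i_pipeline_cycle
        info[INFO_EPISODE] = i_episode
        info[INFO_ENV_LAYOUT_SEED] = env_layout_seed
        info[INFO_STEP] = 0  # the step counter restarts after a reset
        info[INFO_TEST_MODE] = False  # post-reset callbacks here run in training mode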

aintelope/agents/sb3_instincts.py

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,7 @@ def should_override(
         self,
         deterministic: bool = False,  # This is set only during evaluation, not training and the meaning is that the agent is greedy - it takes the best action. It does NOT mean that the action is always same.
         step: int = 0,
+        env_layout_seed: int = 0,
         episode: int = 0,
         pipeline_cycle: int = 0,
         test_mode: bool = False,
@@ -198,6 +199,7 @@ def get_action(
         observation=None,
         info: dict = {},
         step: int = 0,
+        env_layout_seed: int = 0,
         episode: int = 0,
         pipeline_cycle: int = 0,
         test_mode: bool = False,
