Skip to content

Commit 2f08d97

Browse files
committed
fix: fix compensator saving
1 parent ff507f1 commit 2f08d97

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

omnisafe/algorithms/off_policy/ddpg_cbf.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _init_env(self) -> None:
5151
self._seed,
5252
self._cfgs,
5353
)
54-
solver = PendulumSolver(device=self._cfgs.train_cfgs.device)
54+
solver = PendulumSolver(device=self._device)
5555
compensator = BarrierCompensator(
5656
obs_dim=self._env.observation_space.shape[0],
5757
act_dim=self._env.action_space.shape[0],
@@ -120,11 +120,18 @@ def _specific_save(self) -> None:
120120
os.makedirs(os.path.dirname(path), exist_ok=True)
121121
joblib.dump(self._env.gp_models, path)
122122

123-
def _log_what_to_save(self) -> dict[str, Any]:
124-
"""Define what need to be saved below."""
123+
def _setup_torch_saver(self) -> None:
124+
"""Define what need to be saved below.
125+
126+
OmniSafe's main storage interface is based on PyTorch. If you need to save models in other
127+
formats, please use :meth:`_specific_save`.
128+
"""
125129
what_to_save: dict[str, Any] = {}
126130

127131
what_to_save['pi'] = self._actor_critic.actor
128132
what_to_save['compensator'] = self._env.compensator
133+
if self._cfgs.algo_cfgs.obs_normalize:
134+
obs_normalizer = self._env.save()['obs_normalizer']
135+
what_to_save['obs_normalizer'] = obs_normalizer
129136

130137
self._logger.setup_torch_saver(what_to_save)

0 commit comments

Comments
 (0)