@@ -51,7 +51,7 @@ def _init_env(self) -> None:
5151 self ._seed ,
5252 self ._cfgs ,
5353 )
54- solver = PendulumSolver (device = self ._cfgs . train_cfgs . device )
54+ solver = PendulumSolver (device = self ._device )
5555 compensator = BarrierCompensator (
5656 obs_dim = self ._env .observation_space .shape [0 ],
5757 act_dim = self ._env .action_space .shape [0 ],
@@ -120,11 +120,18 @@ def _specific_save(self) -> None:
120120 os .makedirs (os .path .dirname (path ), exist_ok = True )
121121 joblib .dump (self ._env .gp_models , path )
122122
123- def _log_what_to_save (self ) -> dict [str , Any ]:
124- """Define what need to be saved below."""
123+ def _setup_torch_saver (self ) -> None :
124+ """Define what need to be saved below.
125+
126+ OmniSafe's main storage interface is based on PyTorch. If you need to save models in other
127+ formats, please use :meth:`_specific_save`.
128+ """
125129 what_to_save : dict [str , Any ] = {}
126130
127131 what_to_save ['pi' ] = self ._actor_critic .actor
128132 what_to_save ['compensator' ] = self ._env .compensator
133+ if self ._cfgs .algo_cfgs .obs_normalize :
134+ obs_normalizer = self ._env .save ()['obs_normalizer' ]
135+ what_to_save ['obs_normalizer' ] = obs_normalizer
129136
130137 self ._logger .setup_torch_saver (what_to_save )
0 commit comments