35
35
TensorDictReplayBuffer ,
36
36
)
37
37
from torchrl .envs .libs .gym import _has_gym
38
- from torchrl .trainers import Recorder , Trainer
38
+ from torchrl .trainers import LogValidationReward , Trainer
39
39
from torchrl .trainers .helpers import transformed_env_constructor
40
40
from torchrl .trainers .trainers import (
41
41
_has_tqdm ,
42
42
_has_ts ,
43
43
BatchSubSampler ,
44
44
CountFramesLog ,
45
- LogReward ,
45
+ LogScalar ,
46
46
mask_batch ,
47
47
OptimizerHook ,
48
48
ReplayBufferTrainer ,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
638
638
trainer = mocking_trainer ()
639
639
trainer .collected_frames = 0
640
640
641
- log_reward = LogReward (logname , log_pbar = pbar )
641
+ log_reward = LogScalar (logname , log_pbar = pbar )
642
642
trainer .register_op ("pre_steps_log" , log_reward )
643
643
td = TensorDict ({REWARD_KEY : torch .ones (3 )}, [3 ])
644
644
trainer ._pre_steps_log_hook (td )
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
654
654
trainer = mocking_trainer ()
655
655
trainer .collected_frames = 0
656
656
657
- log_reward = LogReward (logname , log_pbar = pbar )
657
+ log_reward = LogScalar (logname , log_pbar = pbar )
658
658
log_reward .register (trainer )
659
659
td = TensorDict ({REWARD_KEY : torch .ones (3 )}, [3 ])
660
660
trainer ._pre_steps_log_hook (td )
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
873
873
logger = logger ,
874
874
)()
875
875
876
- recorder = Recorder (
876
+ recorder = LogValidationReward (
877
877
record_frames = args .record_frames ,
878
878
frame_skip = args .frame_skip ,
879
879
policy_exploration = None ,
@@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8):
919
919
os .environ ["CKPT_BACKEND" ] = backend
920
920
state_dict_has_been_called = [False ]
921
921
load_state_dict_has_been_called = [False ]
922
- Recorder .state_dict , Recorder_state_dict = _fun_checker (
923
- Recorder .state_dict , state_dict_has_been_called
922
+ LogValidationReward .state_dict , Recorder_state_dict = _fun_checker (
923
+ LogValidationReward .state_dict , state_dict_has_been_called
924
+ )
925
+ (LogValidationReward .load_state_dict , Recorder_load_state_dict ,) = _fun_checker (
926
+ LogValidationReward .load_state_dict , load_state_dict_has_been_called
924
927
)
925
- (
926
- Recorder .load_state_dict ,
927
- Recorder_load_state_dict ,
928
- ) = _fun_checker (Recorder .load_state_dict , load_state_dict_has_been_called )
929
928
930
929
args = self ._get_args ()
931
930
@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
948
947
)()
949
948
environment .rollout (2 )
950
949
951
- recorder = Recorder (
950
+ recorder = LogValidationReward (
952
951
record_frames = args .record_frames ,
953
952
frame_skip = args .frame_skip ,
954
953
policy_exploration = None ,
@@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname):
969
968
assert recorder2 ._count == 8
970
969
assert state_dict_has_been_called [0 ]
971
970
assert load_state_dict_has_been_called [0 ]
972
- Recorder .state_dict = Recorder_state_dict
973
- Recorder .load_state_dict = Recorder_load_state_dict
971
+ LogValidationReward .state_dict = Recorder_state_dict
972
+ LogValidationReward .load_state_dict = Recorder_load_state_dict
974
973
975
974
976
975
def test_updateweights ():
0 commit comments