        'hidden_size_list': [64, 1],
        'update_per_collect': 200,
        'batch_size': 128,
-    }, {
-        'type': 'trex',
-        'exp_name': 'cartpole_trex_offppo_seed0',
-        'min_snippet_length': 5,
-        'max_snippet_length': 100,
-        'checkpoint_min': 0,
-        'checkpoint_max': 6,
-        'checkpoint_step': 6,
-        'learning_rate': 1e-5,
-        'update_per_collect': 1,
-        'expert_model_path': 'cartpole_ppo_offpolicy_seed0',
-        'hidden_size_list': [512, 64, 1],
-        'obs_shape': 4,
-        'action_shape': 2,
    }
]
@@ -67,15 +53,9 @@ def test_irl(reward_model_config):
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
-    if reward_model_config.type == 'trex':
-        trex_config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
-        trex_config[0].reward_model = reward_model_config
-        args = EasyDict({'cfg': deepcopy(trex_config), 'seed': 0, 'device': 'cpu'})
-        trex_collecting_data(args=args)
-    else:
-        collect_demo_data(
-            config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
-        )
+    collect_demo_data(
+        config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+    )
    # irl + rl training
    cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
    cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
@@ -88,9 +68,6 @@ def test_irl(reward_model_config):
    cp_cartpole_dqn_config.policy.collect.n_sample = 128
    cooptrain_reward = True
    pretrain_reward = False
-    if reward_model_config.type == 'trex':
-        cooptrain_reward = False
-        pretrain_reward = True
    serial_pipeline_reward_model_offpolicy(
        (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config),
        seed=0,
@@ -126,3 +103,31 @@ def test_ngu():
        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
    except Exception:
        assert False, "pipeline fail"
+
+
+@pytest.mark.unittest
+def test_trex():
+    exp_name = 'test_serial_pipeline_trex_expert'
+    config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+    config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
+    config[0].exp_name = exp_name
+    expert_policy = serial_pipeline(config, seed=0)
+
+    exp_name = 'test_serial_pipeline_trex_collect'
+    config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
+    config[0].exp_name = exp_name
+    config[0].reward_model.exp_name = exp_name
+    config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert'
+    config[0].reward_model.checkpoint_max = 100
+    config[0].reward_model.checkpoint_step = 100
+    config[0].reward_model.num_snippets = 100
+    args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
+    trex_collecting_data(args=args)
+    try:
+        serial_pipeline_reward_model_offpolicy(
+            config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
+        )
+    except Exception:
+        assert False, "pipeline fail"
+    finally:
+        os.popen('rm -rf test_serial_pipeline_trex*')
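Note: the diff moves TREX out of the parametrized `test_irl` cases and into a dedicated `test_trex`, which runs the full flow of training a PPO expert, collecting ranked snippets with `trex_collecting_data`, and then pretraining the reward model before RL. Below is a minimal standalone sketch of that same flow; the import paths and the experiment names (`trex_expert_demo`, `trex_collect_demo`) are assumptions based on a typical DI-engine layout and are not part of this diff.

```python
# Minimal sketch of the three-stage TREX flow exercised by test_trex above.
# NOTE: import paths are assumptions (not part of this diff); adjust to your DI-engine checkout.
from copy import deepcopy
from easydict import EasyDict

from ding.entry import serial_pipeline, serial_pipeline_reward_model_offpolicy
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
from dizoo.classic_control.cartpole.config import (
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config,
    cartpole_trex_offppo_config, cartpole_trex_offppo_create_config,
)

# 1. Train a PPO expert on CartPole, saving periodic checkpoints for TREX to rank later.
expert_cfg = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_cfg[0].exp_name = 'trex_expert_demo'
expert_cfg[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
serial_pipeline(expert_cfg, seed=0)

# 2. Roll out the saved checkpoints and collect ranked trajectory snippets.
trex_cfg = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
trex_cfg[0].exp_name = 'trex_collect_demo'
trex_cfg[0].reward_model.exp_name = 'trex_collect_demo'
trex_cfg[0].reward_model.expert_model_path = 'trex_expert_demo'
trex_collecting_data(args=EasyDict({'cfg': deepcopy(trex_cfg), 'seed': 0, 'device': 'cpu'}))

# 3. Pretrain the TREX reward model on the ranked data, then run RL without co-training it.
serial_pipeline_reward_model_offpolicy(
    trex_cfg, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
)
```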