
Commit be03aa9 (parent 822d7a4)

change clear data from pipeline to RM && add ngu to new entry

20 files changed: +420 −277 lines

ding/entry/serial_entry_reward_model_offpolicy.py

Lines changed: 2 additions & 3 deletions
@@ -115,9 +115,8 @@ def serial_pipeline_reward_model_offpolicy(
         # update reward_model, when you want to train reward_model inloop
         if cooptrain_reward:
             reward_model.train()
-            # clear buffer per fix iters to make sure replay buffer's data count isn't too few.
-            if hasattr(cfg.reward_model, 'clear_buffer_per_iters') and count % cfg.reward_model.clear_buffer_per_iters == 0:
-                reward_model.clear_data()
+            # clear buffer per fix iters to make sure replay buffer's data count isn't too few.
+            reward_model.clear_data(iter=count)
         # Learn policy from collected data
         for i in range(cfg.policy.learn.update_per_collect):
             # Learner will train ``update_per_collect`` times in one iteration.

ding/entry/serial_entry_reward_model_onpolicy.py

Lines changed: 1 addition & 2 deletions
@@ -114,8 +114,7 @@ def serial_pipeline_reward_model_onpolicy(
         # update reward_model
         if cooptrain_reward:
             reward_model.train()
-            if hasattr(cfg.reward_model, 'clear_buffer_per_iters') and count % cfg.reward_model.clear_buffer_per_iters == 0:
-                reward_model.clear_data()
+            reward_model.clear_data(iter=count)
         # Learn policy from collected data
         for i in range(cfg.policy.learn.update_per_collect):
             # Learner will train ``update_per_collect`` times in one iteration.
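
Together, these two entry changes make the pipeline pass the current iteration count to the reward model and let the model decide when to actually flush its buffer. A minimal sketch of the contract each concrete model now implements (mirroring the implementations further down in this commit; ``self.cfg`` is assumed to carry ``clear_buffer_per_iters``):

def clear_data(self, iter: int) -> None:
    # Flush the reward model's internal training buffer only every
    # ``clear_buffer_per_iters`` iterations, so its data pool never shrinks
    # too far between collect steps.
    assert hasattr(self.cfg, 'clear_buffer_per_iters'), \
        "Reward Model does not have clear_buffer_per_iters, Clear failed"
    if iter % self.cfg.clear_buffer_per_iters == 0:
        self.train_data.clear()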

ding/entry/tests/test_serial_entry_reward_model.py

Lines changed: 10 additions & 0 deletions
@@ -10,6 +10,7 @@
 from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config  # noqa
+from dizoo.classic_control.cartpole.config.cartpole_ngu_config import cartpole_ngu_config, cartpole_ngu_create_config
 from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
     serial_pipeline_reward_model_onpolicy
 from ding.entry.application_entry_trex_collect_data import trex_collecting_data
@@ -116,3 +117,12 @@ def test_icm():
         serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
     except Exception:
         assert False, "pipeline fail"
+
+
+@pytest.mark.unittest
+def test_ngu():
+    config = [deepcopy(cartpole_ngu_config), deepcopy(cartpole_ngu_create_config)]
+    try:
+        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
+    except Exception:
+        assert False, "pipeline fail"
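
The new test simply routes the shipped cartpole NGU config through the existing off-policy entry. The same entry can be run directly outside pytest; a minimal sketch, with an illustrative training budget in place of the test's two iterations:

from copy import deepcopy
from ding.entry import serial_pipeline_reward_model_offpolicy
from dizoo.classic_control.cartpole.config.cartpole_ngu_config import cartpole_ngu_config, cartpole_ngu_create_config

if __name__ == '__main__':
    # Identical call to test_ngu above; max_train_iter=100 is illustrative, not from the commit.
    config = [deepcopy(cartpole_ngu_config), deepcopy(cartpole_ngu_create_config)]
    serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=100)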

ding/policy/ngu.py

Lines changed: 1 addition & 1 deletion
@@ -431,7 +431,7 @@ def _init_collect(self) -> None:
         # epsilon=0.4, alpha=9
         self.eps = {i: 0.4 ** (1 + 8 * i / (self._cfg.collect.env_num - 1)) for i in range(self._cfg.collect.env_num)}
 
-    def _forward_collect(self, data: dict) -> dict:
+    def _forward_collect(self, data: dict, eps: Optional[float]) -> dict:
         r"""
         Overview:
             Collect output according to eps_greedy plugin

ding/reward_model/base_reward_model.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def collect_data(self, data) -> None:
         raise NotImplementedError()
 
     @abstractmethod
-    def clear_data(self) -> None:
+    def clear_data(self, iter: int) -> None:
         """
         Overview:
             Clearing training data. \

ding/reward_model/gail_irl_model.py

Lines changed: 6 additions & 2 deletions
@@ -212,10 +212,14 @@ def collect_data(self, data: list) -> None:
         data = torch.unbind(data, dim=0)
         self.train_data.extend(data)
 
-    def clear_data(self) -> None:
+    def clear_data(self, iter: int) -> None:
         """
         Overview:
             Clearing training data. \
             This is a side effect function which clears the data attribute in ``self``
         """
-        self.train_data.clear()
+        assert hasattr(
+            self.cfg, 'clear_buffer_per_iters'
+        ), "Reward Model does not have clear_buffer_per_iters, Clear failed"
+        if iter % self.cfg.clear_buffer_per_iters == 0:
+            self.train_data.clear()

ding/reward_model/guided_cost_reward_model.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def collect_data(self, data) -> None:
         # if online_net is trained continuously, there should be some implementations in collect_data method
         pass
 
-    def clear_data(self):
+    def clear_data(self, iter: int):
         """
         Overview:
             Collecting clearing data, not implemented if reward model (i.e. online_net) is only trained ones, \

ding/reward_model/icm_reward_model.py

Lines changed: 9 additions & 5 deletions
@@ -175,11 +175,15 @@ def collect_data(self, data: list) -> None:
         self.train_next_states.extend(next_states)
         self.train_actions.extend(actions)
 
-    def clear_data(self) -> None:
-        self.train_data.clear()
-        self.train_states.clear()
-        self.train_next_states.clear()
-        self.train_actions.clear()
+    def clear_data(self, iter: int) -> None:
+        assert hasattr(
+            self.cfg, 'clear_buffer_per_iters'
+        ), "Reward Model does not have clear_buffer_per_iters, Clear failed"
+        if iter % self.cfg.clear_buffer_per_iters == 0:
+            self.train_data.clear()
+            self.train_states.clear()
+            self.train_next_states.clear()
+            self.train_actions.clear()
 
     def state_dict(self) -> Dict:
         return self.reward_model.state_dict()

ding/reward_model/ngu_reward_model.py

Lines changed: 90 additions & 0 deletions
@@ -1,5 +1,6 @@
 import copy
 import random
+from typing import Any
 
 import numpy as np
 import torch
@@ -408,3 +409,92 @@ def fusion_reward(
                         int(data[i]['beta'][j])]
 
     return data, estimate_cnt
+
+
+@REWARD_MODEL_REGISTRY.register('ngu-reward')
+class NGURewardModel(BaseRewardModel):
+    r"""
+    Overview:
+        The unifying reward for ngu which combined rnd-ngu and episodic
+        The corresponding paper is `never give up: learning directed exploration strategies`.
+    """
+    config = dict(
+        type='ngu-reward',
+        policy_nstep=5,
+        collect_env_num=8,
+        rnd_reward_model=dict(
+            intrinsic_reward_type='add',
+            learning_rate=5e-4,
+            obs_shape=4,
+            action_shape=2,
+            batch_size=128,  # transitions
+            update_per_collect=10,
+            only_use_last_five_frames_for_icm_rnd=False,
+            clear_buffer_per_iters=10,
+            nstep=5,
+            hidden_size_list=[128, 128, 64],
+            type='rnd-ngu',
+        ),
+        episodic_reward_model=dict(
+            last_nonzero_reward_rescale=False,
+            last_nonzero_reward_weight=1,
+            intrinsic_reward_type='add',
+            learning_rate=5e-4,
+            obs_shape=4,
+            action_shape=2,
+            batch_size=128,  # transitions
+            update_per_collect=10,
+            only_use_last_five_frames_for_icm_rnd=False,
+            clear_buffer_per_iters=10,
+            nstep=5,
+            hidden_size_list=[128, 128, 64],
+            type='episodic',
+        ),
+    )
+
+    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:
+        super(NGURewardModel).__init__()
+        self.cfg = config
+        self.tb_logger = tb_logger
+        self.estimate_cnt = 0
+        self.rnd_reward_model = RndNGURewardModel(config.rnd_reward_model, device, tb_logger)
+        self.episodic_reward_model = EpisodicNGURewardModel(config.episodic_reward_model, device, tb_logger)
+
+    def train(self) -> None:
+        self.rnd_reward_model.train()
+        self.episodic_reward_model.train()
+
+    def estimate(self, data: list) -> dict:
+
+        # estimate reward
+        rnd_reward = self.rnd_reward_model.estimate(data)
+        episodic_reward = self.episodic_reward_model.estimate(data)
+
+        # combine reward
+        train_data_augumented, self.estimate_cnt = self.episodic_reward_model.fusion_reward(
+            data,
+            episodic_reward,
+            rnd_reward,
+            nstep=self.cfg.policy_nstep,
+            collector_env_num=self.cfg.collect_env_num,
+            tb_logger=self.tb_logger,
+            estimate_cnt=self.estimate_cnt
+        )
+
+        return train_data_augumented
+
+    def collect_data(self, data) -> None:
+        self.rnd_reward_model.collect_data(data)
+        self.episodic_reward_model.collect_data(data)
+
+    def clear_data(self, iter: int) -> None:
+        assert hasattr(
+            self.cfg.rnd_reward_model, 'clear_buffer_per_iters'
+        ), "RND Reward Model does not have clear_buffer_per_iters, Clear failed"
+        assert hasattr(
+            self.cfg.episodic_reward_model, 'clear_buffer_per_iters'
+        ), "Episodic Reward Model does not have clear_buffer_per_iters, Clear failed"
+        if iter % self.cfg.rnd_reward_model.clear_buffer_per_iters == 0:
+            self.rnd_reward_model.clear_data()
+        if iter % self.cfg.episodic_reward_model.clear_buffer_per_iters == 0:
+            self.episodic_reward_model.clear_data()
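
For orientation, a standalone construction sketch of the new model, assuming the default ``config`` above and the ``tensorboardX`` ``SummaryWriter`` that DI-engine commonly uses; in the serial pipeline the model is normally built from the config's registered ``type='ngu-reward'`` rather than instantiated by hand, and the log path and device below are illustrative:

from easydict import EasyDict
from tensorboardX import SummaryWriter
from ding.reward_model.ngu_reward_model import NGURewardModel

# Hypothetical standalone setup; the serial entry wires these pieces up automatically.
cfg = EasyDict(NGURewardModel.config)
tb_logger = SummaryWriter('./log/ngu_reward_model')
reward_model = NGURewardModel(cfg, device='cpu', tb_logger=tb_logger)

# Per collect iteration ``count``, mirroring the off-policy entry above:
#   reward_model.collect_data(new_data)
#   reward_model.train()
#   data_with_intrinsic_reward = reward_model.estimate(sampled_data)
#   reward_model.clear_data(iter=count)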

ding/reward_model/pdeil_irl_model.py

Lines changed: 6 additions & 2 deletions
@@ -210,10 +210,14 @@ def collect_data(self, item: list):
         """
         self.train_data.extend(item)
 
-    def clear_data(self):
+    def clear_data(self, iter: int):
         """
        Overview:
            Clearing training data. \
            This is a side effect function which clears the data attribute in ``self``
        """
-        self.train_data.clear()
+        assert hasattr(
+            self.cfg, 'clear_buffer_per_iters'
+        ), "Reward Model does not have clear_buffer_per_iters, Clear failed"
+        if iter % self.cfg.clear_buffer_per_iters == 0:
+            self.train_data.clear()
