add RL4LMs summarization #1078
Open: dwyzzy wants to merge 39 commits into PaddlePaddle:develop from dwyzzy:sentence_review_summarization
Commits (39):
- a1a4c4b: init file using files from RL4LMS (dwyzzy)
- e706ed4: benchmark of RL4LMs v0.0 (dwyzzy)
- f816028: benchmark of RL4LMs v0.1 (dwyzzy)
- 0293734: fix pg reward bug, remove no use warmstartup (dwyzzy)
- 02efdd9: merge models and buffers, add README.md (dwyzzy)
- 89c4efb: simplified code v0.0 (dwyzzy)
- 0b69359: remove distribution_wrapper.py and sample_util.py (dwyzzy)
- 23735cb: remove EvaluateActionsOutput, ValueOutput and PolicyOutput (dwyzzy)
- bbdd102: use Reviewer and ReviewerGroup instead of Env (dwyzzy)
- a9aef6b: use Reviewer and ReviewerGroup instead of Env (parl parallel) (dwyzzy)
- bf3c625: use Reviewer and ReviewerGroup instead of Env (parl parallel version) (dwyzzy)
- d452685: review using sentence (parl parallel version) (dwyzzy)
- b943f1c: remove some '**config' and change rollout util (dwyzzy)
- 086ce6f: use instructor instead of reviewer, add examiner (dwyzzy)
- 3acf2c3: add requirements.txt (dwyzzy)
- 090b190: change code style (dwyzzy)
- 78f44b8: Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
- b66f07e: change train.py style (dwyzzy)
- 0d8af33: Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy)
- 337ac75: change style (dwyzzy)
- d0ced44: change style (dwyzzy)
- 151fcea: change code style(add copyright) (dwyzzy)
- f91d2c9: bring for-batch-rollout loop out of rl4lms_ppo (dwyzzy)
- dc1d835: change name of policy/value, obs-preprocess and add-to-buffer (dwyzzy)
- a23e8fe: change config structure (dwyzzy)
- c2be52f: change ppo code style according to parl ppo (dwyzzy)
- b34ea18: yapf code style (dwyzzy)
- 02c8956: change code for PARL-RL4LMs summarization version 0.1 (dwyzzy)
- 760cc9d: change code style of PARL-RL4LMs summarization version 0.1 (dwyzzy)
- 1770e45: change unreasonable name to n_steps_per_instructor in config (dwyzzy)
- b9c3e5c: add object for all classes, adjust add-to-buffer structure (dwyzzy)
- 59e02fa: change t5_ppo_config and README (dwyzzy)
- 4cd67e2: yapf code style (dwyzzy)
- 6af0e40: Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
- da94226: change buffer add(), add save/load (dwyzzy)
- 68ec090: yapf code style (dwyzzy)
- 2b82da6: Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy)
- 21e99e8: evaluate at beginning (dwyzzy)
- 1704e4f: Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay)
**README.md**
## Reproduce Summarization-RLHF in RL4LMs using PARL

> Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241)
>
> Official code: [RL4LMs](https://github.com/allenai/RL4LMs)
>
> Other code referenced: [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)

### Background

- Summarization task in NLP: the task of producing a shorter version of a document that preserves most of the input's meaning.
- RLHF: Reinforcement Learning from Human Feedback, which uses human knowledge to train RL algorithms. More information is available in the Hugging Face blog post [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf).
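To make the RLHF idea concrete: the buffer code later in this PR tracks a `task_reward` and a `kl_reward` per transition and sums them into a `total_reward`. A minimal sketch of that standard decomposition follows; the coefficient `beta` and the exact form are assumptions here, not values taken from this PR's config:

```python
def shaped_reward(task_reward: float, log_prob: float, ref_log_prob: float,
                  beta: float = 0.1) -> float:
    """Schematic RLHF reward shaping (a sketch, not this PR's exact code).

    beta is an illustrative KL coefficient; log_prob / ref_log_prob are the
    policy's and the frozen reference model's log-probabilities of the action.
    """
    kl_reward = -beta * (log_prob - ref_log_prob)  # penalize drift from the reference LM
    return task_reward + kl_reward
```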
### Main contribution

- Restructure **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** into
  **\{ trainer: \{ env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}**, following the PARL architecture (a minimal sketch follows this list).
- Build a new Summarization-RLHF framework using PARL.
- Use PARL parallel training.
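A rough illustration of that layering, using PARL's `Model`/`Algorithm`/`Agent` base classes; the class names and method bodies here are placeholders, not the identifiers used in this PR:

```python
# Minimal sketch of the PARL-style decomposition described above.
# SummarizationModel / PPO / SummarizationAgent are illustrative names only.
import parl


class SummarizationModel(parl.Model):
    """Policy/value networks, e.g. a seq2seq LM with a value head."""

    def policy(self, obs):
        raise NotImplementedError  # return action-distribution parameters

    def value(self, obs):
        raise NotImplementedError  # return state-value estimates


class PPO(parl.Algorithm):
    """Owns only the model and the PPO update rule, not env or buffer."""

    def __init__(self, model, clip_ratio=0.2):
        super().__init__(model)
        self.clip_ratio = clip_ratio

    def learn(self, rollout_batch):
        raise NotImplementedError  # PPO loss + optimizer step


class SummarizationAgent(parl.Agent):
    """Thin bridge between the trainer's rollout data and the algorithm."""

    def learn(self, rollout_batch):
        return self.alg.learn(rollout_batch)


# The trainer level holds env (instructors), rollout_buffer and the agent,
# matching { trainer: { env, rollout_buffer, agent: { ppo: { model } } } }.
```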
### How to use

#### Install dependencies

```bash
pip install -r requirements.txt
```

#### Start training

```bash
# start xparl
xparl start --port 8811 --cpu_num 10

# start training
python train.py
```
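For orientation, a PARL training script typically attaches to the xparl cluster started above along these lines. This is a hedged sketch: the actual wiring in this PR's `train.py` may differ, and the `Instructor` class below is only a stand-in for the PR's instructor workers:

```python
# Sketch of attaching to the xparl cluster started above; illustrative only.
import parl


@parl.remote_class
class Instructor(object):
    """Stand-in for this PR's remote rollout workers ('instructors')."""

    def generate_rollout(self):
        pass  # run the policy on prompts and return one episode of transitions


if __name__ == "__main__":
    parl.connect("localhost:8811")  # must match `xparl start --port 8811`
    instructors = [Instructor() for _ in range(10)]  # one per reserved CPU
```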
### Code Reference

- Official code: [RL4LMs](https://github.com/allenai/RL4LMs)
- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
**`DictRolloutBuffer`**

The buffer class now inherits explicitly from `object` (per commit b9c3e5c, "add object for all classes, adjust add-to-buffer structure"):

```python
class DictRolloutBuffer(object):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations
    """
```
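For orientation, the kind of dictionary observation space such a buffer handles might look like the following. The keys and sizes are hypothetical; the PR's actual observation space is defined elsewhere and is not visible in this diff:

```python
# Hypothetical dict observation space for a prompt-conditioned LM task;
# key names and shapes are illustrative, not taken from this PR.
import numpy as np
from gym import spaces

observation_space = spaces.Dict({
    "input_ids": spaces.Box(low=0, high=50257, shape=(512, ), dtype=np.int64),
    "attention_mask": spaces.Box(low=0, high=1, shape=(512, ), dtype=np.int8),
})
```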
The per-transition `add()` method is removed:

```python
def add(
        self,
        obs,
        action,
        reward,
        episode_start,
        value,
        log_prob,
):
    """
    :param obs: Observation
    :param action: Action
    :param reward: reward for the transition
    :param episode_start: Start of episode signal.
    :param value: estimated value of the current state
        following the current policy.
    :param log_prob: log probability of the action
        following the current policy.
    """
    ...  # body folded into add_transitions() below
```

Its logic moves into the new episode-level `add_transitions()` shown next.

> **Collaborator** (on the removed `add()` method): you don't have to remove
The new method consumes whole episodes of transitions, logs rollout statistics into `rollout_info`, and computes advantages once the buffer fills:

```python
def add_transitions(self, episode_wise_transitions, rollout_info):
    advantages_computed = False
    for ep_ix, transitions in enumerate(episode_wise_transitions):
        ep_length = len(transitions)
        total_reward = 0.0
        total_kl_reward = 0.0
        for transition_ix, transition in enumerate(transitions):
            total_reward += transition.task_reward
            total_kl_reward += transition.kl_reward
            rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div)
            rollout_info["rollout_info/log_prob"].append(transition.log_prob)
            rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob)
            rollout_info["rollout_info/values"].append(transition.value.numpy())

            # add to buffer
            if not self.full:
                obs = transition.observation
                action = transition.action
                reward = transition.total_reward
                episode_start = transition.episode_start
                value = transition.value
                log_prob = transition.log_prob
                if len(log_prob.shape) == 0:
                    # Reshape 0-d tensor to avoid error
                    log_prob = log_prob.reshape(-1, 1)

                for key in self.observations.keys():
                    obs_ = np.array(obs[key]).copy()
                    # Reshape needed when using multiple instructors with discrete observations
                    # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
                    if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                        obs_ = obs_.reshape((1, ) + self.obs_shape[key])
                    self.observations[key][self.pos] = obs_

                self.actions[self.pos] = np.array(action).copy()
                self.rewards[self.pos] = np.array(reward).copy()
                self.episode_starts[self.pos] = np.array(episode_start).copy()
                self.values[self.pos] = value.clone().cpu().numpy().flatten()
                self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
                self.pos += 1
                if self.pos == self.buffer_size:
                    self.full = True

            # if the buffer is full, compute advantages
            if self.full and not advantages_computed:
                # we fetch the last value for the last time step;
                # values come from the next transition's values
                next_values = (transitions[transition_ix + 1].value if
                               (transition_ix + 1) < ep_length else torch.tensor([0.0]))

                self.compute_returns_and_advantage(last_values=next_values, dones=transition.done)
                advantages_computed = True

        rollout_info["rollout_info/ep_rew"].append(total_reward)
        rollout_info["rollout_info/ep_lens"].append(ep_length)
        rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward)
```
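A rough picture of how a trainer-side caller might feed this method. This is a hypothetical sketch: only the `rollout_info` keys and the transition fields come from the code above; `instructors`, `generate_rollout()` and `rollout_buffer` are illustrative names:

```python
# Hypothetical caller; add_transitions() fills rollout_info via .append(),
# so a defaultdict(list) is a convenient container.
from collections import defaultdict

rollout_info = defaultdict(list)
episode_wise_transitions = [
    instructor.generate_rollout()  # one list of transitions per episode
    for instructor in instructors  # e.g. the parallel instructor workers
]

rollout_buffer.add_transitions(episode_wise_transitions, rollout_info)

# rollout_info now holds "rollout_info/ep_rew", "rollout_info/ep_lens",
# "rollout_info/ep_kl_rew", "rollout_info/kl_div_mean", ... for logging.
```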
`add_transitions()` delegates advantage estimation to the buffer's existing `compute_returns_and_advantage(last_values, dones)` method, whose body is truncated in this diff.
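Since this buffer is adapted from Stable Baselines3, `compute_returns_and_advantage` presumably implements GAE(λ) in the SB3 style. A minimal standalone sketch of that standard computation, stated as an assumption because the method body is not shown here:

```python
import numpy as np


def compute_gae(rewards, values, episode_starts, last_values, dones,
                gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation as used by SB3-style rollout buffers.

    A sketch of the standard algorithm, not this PR's exact code.
    """
    buffer_size = len(rewards)
    advantages = np.zeros(buffer_size, dtype=np.float32)
    last_gae_lam = 0.0
    for step in reversed(range(buffer_size)):
        if step == buffer_size - 1:
            next_non_terminal = 1.0 - float(dones)   # bootstrap off the final state
            next_values = last_values
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_values = values[step + 1]
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    returns = advantages + values                    # TD(lambda) returns
    return advantages, returns
```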