Commits (39):

- `a1a4c4b` init file using files from RL4LMS (dwyzzy, Mar 1, 2023)
- `e706ed4` benchmark of RL4LMs v0.0 (dwyzzy, Mar 2, 2023)
- `f816028` benchmark of RL4LMs v0.1 (dwyzzy, Mar 6, 2023)
- `0293734` fix pg reward bug, remove no use warmstartup (dwyzzy, Mar 6, 2023)
- `02efdd9` merge models and buffers, add README.md (dwyzzy, Mar 6, 2023)
- `89c4efb` simplified code v0.0 (dwyzzy, Mar 7, 2023)
- `0b69359` remove distribution_wrapper.py and sample_util.py (dwyzzy, Mar 7, 2023)
- `23735cb` remove EvaluateActionsOutput, ValueOutput and PolicyOutput (dwyzzy, Mar 7, 2023)
- `bbdd102` use Reviewer and ReviewerGroup instead of Env (dwyzzy, Mar 8, 2023)
- `a9aef6b` use Reviewer and ReviewerGroup instead of Env (parl parallel) (dwyzzy, Mar 9, 2023)
- `bf3c625` use Reviewer and ReviewerGroup instead of Env (parl parallel version) (dwyzzy, Mar 9, 2023)
- `d452685` review using sentence (parl parallel version) (dwyzzy, Mar 10, 2023)
- `b943f1c` remove some '**config' and change rollout util (dwyzzy, Mar 10, 2023)
- `086ce6f` use instructor instead of reviewer, add examiner (dwyzzy, Mar 13, 2023)
- `3acf2c3` add requirements.txt (dwyzzy, Mar 13, 2023)
- `090b190` change code style (dwyzzy, Mar 13, 2023)
- `78f44b8` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay, Mar 13, 2023)
- `b66f07e` change train.py style (dwyzzy, Mar 13, 2023)
- `0d8af33` Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy, Mar 13, 2023)
- `337ac75` change style (dwyzzy, Mar 13, 2023)
- `d0ced44` change style (dwyzzy, Mar 13, 2023)
- `151fcea` change code style(add copyright) (dwyzzy, Mar 14, 2023)
- `f91d2c9` bring for-batch-rollout loop out of rl4lms_ppo (dwyzzy, Mar 14, 2023)
- `dc1d835` change name of policy/value , obs-preprocess and add-to-buffer (dwyzzy, Mar 15, 2023)
- `a23e8fe` change config structure (dwyzzy, Mar 15, 2023)
- `c2be52f` change ppo code style according to parl ppo (dwyzzy, Mar 15, 2023)
- `b34ea18` yapf code style (dwyzzy, Mar 15, 2023)
- `02c8956` change code for PARL-RL4LMs summarization version 0.1 (dwyzzy, Mar 16, 2023)
- `760cc9d` change code style of PARL-RL4LMs summarization version 0.1 (dwyzzy, Mar 16, 2023)
- `1770e45` change unreasonable name to n_steps_per_instructor in config (dwyzzy, Mar 17, 2023)
- `b9c3e5c` add object for all classes, adjust add-to-buffer structure (dwyzzy, Mar 20, 2023)
- `59e02fa` change t5_ppo_config and README (dwyzzy, Mar 20, 2023)
- `4cd67e2` yapf code style (dwyzzy, Mar 20, 2023)
- `6af0e40` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay, Mar 31, 2023)
- `da94226` change buffer add(), add save/load (dwyzzy, Apr 3, 2023)
- `68ec090` yapf code style (dwyzzy, Apr 3, 2023)
- `2b82da6` Merge remote-tracking branch 'rl4lm_parl/sentence_review_summarizatio… (dwyzzy, Apr 3, 2023)
- `21e99e8` evaluate at beginning (dwyzzy, Apr 3, 2023)
- `1704e4f` Merge branch 'develop' into sentence_review_summarization (TomorrowIsAnOtherDay, Apr 17, 2023)
31 changes: 22 additions & 9 deletions benchmark/torch/RL4LMs/README.md
@@ -1,24 +1,37 @@
## Reproduce (Reconfiguration) Summarization in RL4LMs using PARL
## Reproduce Summarization-RLHF in RL4LMs using PARL

> Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241)
>
> Official code: [RL4LMs](https://github.com/allenai/RL4LMs)
>
> Other code referenced: [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)

### Background

- Summarization: the NLP task of producing a shorter version of a document that preserves most of the input's meaning.
- RLHF: Reinforcement Learning from Human Feedback, which uses human feedback to train RL agents.
More information is available in the Hugging Face blog [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf)

### Main contribution

- Change from **\{ trainer: \{ ppo: \{ env, rollout_buffer, policy/model \} \} \}** to
**\{trainer: \{env, rollout_buffer, agent: \{ ppo: \{ model \} \} \} \}** according to PARL architecture.
- Use Parl parallel Training
- Build new Summarization-RLHF framework using PARL
- Use PARL parallel training

### How to use

### Running command
#### Install dependencies

```bash
pip install -r requirements.txt
```

#### Start training
```bash
# start xparl
xparl start --port 8811 --cpu_num 10

# start training
python train.py
```

### Code Reference

- Official code: [RL4LMs](https://github.com/allenai/RL4LMs)
- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
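
The restructuring described under "Main contribution" follows PARL's trainer / agent / algorithm / model layering. Below is a minimal sketch of that composition for orientation; the class names and method signatures are illustrative assumptions, not the exact identifiers used in this PR.

```python
# Sketch of the PARL-style layering described under "Main contribution" (all names illustrative).
# Old RL4LMs layout: trainer -> ppo -> {env, rollout_buffer, policy/model}
# New PARL layout:   trainer -> {instructors, rollout_buffer, agent -> ppo (algorithm) -> model}


class Model(object):
    """Seq2seq LM exposing policy and value heads."""

    def policy(self, obs, actions):
        raise NotImplementedError

    def value(self, obs):
        raise NotImplementedError


class PPO(object):
    """Algorithm layer: owns the model and the PPO update rule."""

    def __init__(self, model):
        self.model = model

    def learn(self, batch):
        raise NotImplementedError


class Agent(object):
    """Agent layer: wraps the algorithm, prepares inputs, and drives the epoch loop."""

    def __init__(self, alg):
        self.alg = alg

    def learn(self, rollout_buffer):
        raise NotImplementedError


class Trainer(object):
    """Trainer layer: owns the instructors (environment side), the buffer, and the outer loop."""

    def __init__(self, agent, instructor_group, rollout_buffer):
        self.agent = agent
        self.instructor_group = instructor_group
        self.rollout_buffer = rollout_buffer

    def train(self, n_iters):
        for _ in range(n_iters):
            # collect rollouts from the remote instructors, then update the agent
            self.agent.learn(self.rollout_buffer)
```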
31 changes: 25 additions & 6 deletions benchmark/torch/RL4LMs/instructor.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from collections import OrderedDict
import torch
from rl4lms_utils import Observation
@@ -32,7 +32,7 @@ def _flatten_obs(obs, space, n_instructor=None):


@parl.remote_class(wait=False)
class Instructor:
class Instructor(object):
def __init__(
self,
reward_config=None,
@@ -43,6 +43,7 @@ def __init__(
terminate_on_eos=False,
context_start_token=None,
prompt_truncation_side="left",
waiting_time_idx=0,
):
"""
Instructor who gives reward
@@ -53,6 +54,8 @@ def __init__(
context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! )
prompt_truncation_side (str): truncation side for prompt text (Defaults to "left")
"""
time.sleep(
waiting_time_idx * 90) # too many Instructors may cause problems if they load datasets at the same time
tokenizer = build_tokenizer(tokenizer_config)
samples = build_datapool(datapool_config, remote_train=True)["train"]
reward_function = build_reward_fn(reward_config)
@@ -185,7 +188,7 @@ def get_obs_and_action_space(self):
return (self.observation_space, self.action_space)


class InstructorGroup:
class InstructorGroup(object):
def __init__(
self,
instructor_config=None,
@@ -198,9 +201,13 @@ def __init__(
instructor_kwargs = {
"reward_config": instructor_config["reward_fn"],
"tokenizer_config": tokenizer_config,
"datapool_config": datapool_config
"datapool_config": datapool_config,
"max_prompt_length": instructor_config["max_prompt_length"],
"max_episode_length": instructor_config["max_episode_length"],
"terminate_on_eos": instructor_config["terminate_on_eos"],
"prompt_truncation_side": instructor_config["prompt_truncation_side"],
"context_start_token": instructor_config["context_start_token"]
}
instructor_kwargs = {**instructor_kwargs, **instructor_config.get("args", {})}
self.tokenizer = tokenizer
self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"])

@@ -258,4 +265,16 @@ def _instructors_feedback_sentence(self, all_sentences):

def _create_instructors(self, instructor_kwargs, parl_port=None):
parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"])
return [Instructor(**instructor_kwargs) for _ in range(self.n_instructors)]
return [
Instructor(
reward_config=instructor_kwargs["reward_config"],
tokenizer_config=instructor_kwargs["tokenizer_config"],
datapool_config=instructor_kwargs["datapool_config"],
max_episode_length=instructor_kwargs["max_episode_length"],
max_prompt_length=instructor_kwargs["max_prompt_length"],
terminate_on_eos=instructor_kwargs["terminate_on_eos"],
context_start_token=instructor_kwargs["context_start_token"],
prompt_truncation_side=instructor_kwargs["prompt_truncation_side"],
waiting_time_idx=idx,
) for idx in range(self.n_instructors)
]
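
The new `waiting_time_idx` argument staggers Instructor startup so that the parallel actors do not all load the dataset at the same moment. A self-contained sketch of this pattern with a toy actor follows; the address, delay, and class name are assumptions for illustration (the real code uses a 90-second step and passes `distributed_files`).

```python
# Sketch of the staggered-startup pattern used by Instructor above (toy actor, not the real class).
import time

import parl


@parl.remote_class(wait=False)
class ToyWorker(object):
    def __init__(self, worker_idx, startup_delay=5):
        # stagger initialization so parallel workers do not load heavy resources at the same time
        time.sleep(worker_idx * startup_delay)
        self.worker_idx = worker_idx

    def ping(self):
        return self.worker_idx


if __name__ == "__main__":
    # assumes an xparl master is already running at this address (see the README's xparl command)
    parl.connect("localhost:8811")
    workers = [ToyWorker(worker_idx=i) for i in range(4)]
    # with wait=False each remote call returns a future; .get() blocks until the result arrives
    print([w.ping().get() for w in workers])
```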
24 changes: 13 additions & 11 deletions benchmark/torch/RL4LMs/rl4lms_agent.py
@@ -73,14 +73,16 @@ def learn(self, rollout_buffer):
batch_return = rollout_data.returns
batch_value = rollout_data.old_values

continue_training, alg_learn_info = self.alg.learn(
alg_learn_info = self.alg.learn(
batch_obs=batch_obs,
batch_action=batch_action,
batch_value=batch_value,
batch_return=batch_return,
batch_logprob=batch_logprob,
batch_adv=batch_adv)

continue_training = alg_learn_info["continue_training"]

entropy_losses.append(alg_learn_info["entropy_losses"])
pg_losses.append(alg_learn_info["pg_losses"])
value_losses.append(alg_learn_info["value_losses"])
@@ -89,12 +91,13 @@
if not continue_training:
break

self._n_updates += 1 # according to stable-baseline3
self._n_updates += 1 # fix the calculation of self._n_updates
if not continue_training:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}")
break

# self._n_updates += self.n_epochs # change original RL4LMs code
# RL4LMs' method may lead to inaccurate calculation of self._n_updates when continue_training is false
# self._n_updates += self.n_epochs
explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten())

# Logs
@@ -130,10 +133,6 @@ def learn(self, rollout_buffer):

logger.info(ppo_train_info)

def get_inputs_for_generation(self, dict_obs_tensor):
obs_tensor = self.prepare_obs_input(dict_obs_tensor)
return self.alg.model.get_inputs_for_generation(obs_tensor)

def prepare_obs_input(self, obs):
return {key: torch.as_tensor(_obs).to(self.device) for (key, _obs) in obs.items()}

@@ -149,18 +148,21 @@ def policy(self, obs, actions):
actions=actions,
)

def get_log_probs_ref_model(self, obs, action):
return self.alg.get_log_probs_ref_model(obs, action)
def ref_policy(self, obs, action):
return self.alg.ref_policy(obs, action)

def predict(
self,
tokenizer,
dict_obs_tensor=None,
texts=None,
max_prompt_length=None,
input_ids=None,
attention_mask=None,
gen_kwargs=None,
):
obs_tensor = self.prepare_obs_input(dict_obs_tensor)
generation_inputs = self.alg.model.build_inputs(obs_tensor)
input_ids = generation_inputs.inputs
attention_mask = generation_inputs.attention_masks
return self.alg.predict(
input_ids=input_ids,
attention_mask=attention_mask,
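
The rename from `get_log_probs_ref_model` to `ref_policy` pairs with the `kl_reward` and `kl_div` fields that appear in the buffer diff below. As context, here is a hedged sketch of how a KL penalty against the frozen reference policy is typically folded into the task reward in RLHF-style PPO; the `kl_coeff` name and scalar shapes are assumptions, not this PR's exact code.

```python
# Sketch of the KL-shaped reward commonly used in RLHF-style PPO (names illustrative).
import torch


def kl_penalized_reward(task_reward, log_prob, ref_log_prob, kl_coeff=0.2):
    """Combine the task reward (e.g. ROUGE) with a penalty for drifting from the reference policy."""
    kl_div = log_prob - ref_log_prob   # sample-based estimate of the per-step KL term
    kl_reward = -kl_coeff * kl_div     # penalize divergence from the frozen reference model
    total_reward = task_reward + kl_reward
    return total_reward, kl_div, kl_reward


# usage with dummy scalars: log-probs of the chosen action under the current and reference policies
total, kl_div, kl_reward = kl_penalized_reward(
    task_reward=0.5, log_prob=torch.tensor(-1.2), ref_log_prob=torch.tensor(-1.5))
```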
14 changes: 8 additions & 6 deletions benchmark/torch/RL4LMs/rl4lms_ppo.py
@@ -67,7 +67,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro
"value_losses": None,
"clip_fractions": None,
"approx_kl_divs": None,
"loss": None
"loss": None,
"continue_training": None
}

values, _ = self.model.value(batch_obs)
@@ -127,7 +128,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro

if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
return continue_training, learn_info
learn_info["continue_training"] = continue_training
return learn_info

if lr:
for param_group in self.optimizer.param_groups:
@@ -139,8 +141,8 @@ def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logpro
# Clip grad norm
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()

return continue_training, learn_info
learn_info["continue_training"] = continue_training
return learn_info

def value(self, obs):
return self.model.value(obs)
@@ -154,8 +156,8 @@ def policy(self, obs, actions):
actions=actions,
)

def get_log_probs_ref_model(self, obs, action):
return self.model.get_log_probs_ref_model(obs, action)
def ref_policy(self, obs, action):
return self.model.ref_policy(obs, action)

def predict(
self,
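
With this change, `continue_training` travels inside the `learn_info` dict instead of being returned as a separate tuple element. A condensed sketch of the approx-KL early-stopping check that sets the flag, following the stable-baselines3 estimator the surrounding code refers to (variable names are illustrative):

```python
# Sketch of the approx-KL early-stopping logic whose result is now carried in learn_info.
import torch


def ppo_step_info(log_prob, old_log_prob, target_kl=0.02):
    # log-ratio between the new and old policy for the sampled actions
    log_ratio = log_prob - old_log_prob
    with torch.no_grad():
        # low-variance estimator of the KL divergence, as used by stable-baselines3
        approx_kl_div = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).item()

    learn_info = {"approx_kl_divs": approx_kl_div, "continue_training": True}
    if target_kl is not None and approx_kl_div > 1.5 * target_kl:
        # tell the agent-side epoch loop to stop updating on this batch of rollouts
        learn_info["continue_training"] = False
    return learn_info


info = ppo_step_info(torch.tensor([-1.0, -1.3]), torch.tensor([-1.1, -1.2]))
print(info["continue_training"])
```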
98 changes: 57 additions & 41 deletions benchmark/torch/RL4LMs/rl4lms_utils/buffer.py
@@ -49,7 +49,7 @@ def get_obs_shape(observation_space, ):
raise NotImplementedError(f"{observation_space} observation space is not supported")


class DictRolloutBuffer:
class DictRolloutBuffer(object):
"""
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Extends the RolloutBuffer to use dictionary observations
@@ -118,46 +118,62 @@ def reset(self):
self.pos = 0
self.full = False

def add(
self,
obs,
action,
reward,
episode_start,
value,
log_prob,
):
"""
:param obs: Observation
:param action: Action
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""

if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)

for key in self.observations.keys():
obs_ = np.array(obs[key]).copy()
# Reshape needed when using multiple instructors with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((1, ) + self.obs_shape[key])
self.observations[key][self.pos] = obs_

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
def add_transitions(self, episode_wise_transitions, rollout_info):
Reviewer comment (collaborator): you don't have to remove the add method; it abstracts the process of adding a single transition.

advantages_computed = False
for ep_ix, transitions in enumerate(episode_wise_transitions):
ep_length = len(transitions)
total_reward = 0.0
total_kl_reward = 0.0
for transition_ix, transition in enumerate(transitions):
total_reward += transition.task_reward
total_kl_reward += transition.kl_reward
rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div)
rollout_info["rollout_info/log_prob"].append(transition.log_prob)
rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob)
rollout_info["rollout_info/values"].append(transition.value.numpy())

# add to buffer
if not self.full:
obs = transition.observation
action = transition.action
reward = transition.total_reward
episode_start = transition.episode_start
value = transition.value
log_prob = transition.log_prob
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)

for key in self.observations.keys():
obs_ = np.array(obs[key]).copy()
# Reshape needed when using multiple instructors with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((1, ) + self.obs_shape[key])
self.observations[key][self.pos] = obs_

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True

# if the buffer is full, compute advantages
if self.full and not advantages_computed:
# we fetch the last value for the last time step
# values come from the next transitions's values
next_values = (transitions[transition_ix + 1].value if
(transition_ix + 1) < ep_length else torch.tensor([0.0]))

self.compute_returns_and_advantage(last_values=next_values, dones=transition.done)
advantages_computed = True

rollout_info["rollout_info/ep_rew"].append(total_reward)
rollout_info["rollout_info/ep_lens"].append(ep_length)
rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward)

def compute_returns_and_advantage(self, last_values, dones):
"""
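
Following the reviewer comment above, here is a minimal sketch of how `add_transitions` could keep a single-transition `add()` helper rather than inlining the per-transition bookkeeping; the transition field names follow this diff, and the buffer internals are elided.

```python
# Sketch of the reviewer's suggestion: keep a single-transition add() and let
# add_transitions() iterate over it (buffer internals elided; field names follow this diff).
class DictRolloutBufferSketch(object):
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.pos = 0
        self.full = False

    def add(self, obs, action, reward, episode_start, value, log_prob):
        # ... store one transition at index self.pos, exactly as the removed add() did ...
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def add_transitions(self, episode_wise_transitions, rollout_info):
        for transitions in episode_wise_transitions:
            for transition in transitions:
                if self.full:
                    break
                self.add(
                    obs=transition.observation,
                    action=transition.action,
                    reward=transition.total_reward,
                    episode_start=transition.episode_start,
                    value=transition.value,
                    log_prob=transition.log_prob,
                )
            # per-episode stats (ep_rew, ep_lens, ep_kl_rew) would still be appended to rollout_info here
```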
7 changes: 3 additions & 4 deletions benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py
@@ -35,7 +35,7 @@ def build_tokenizer(tokenizer_config):

def build_reward_fn(reward_config):
logger.info(f"loading reward function: rouge")
reward_fn = RougeRewardFunction(**reward_config.get("args", {}))
reward_fn = RougeRewardFunction(rouge_type=reward_config["rouge_type"])
return reward_fn


@@ -48,10 +48,9 @@ def build_metrics(metric_configs):

def build_datapool(datapool_config, remote_train=False):
def _get_datapool_by_split(split):
kwargs = datapool_config.get("args", {})
kwargs["split"] = split
kwargs = {"prompt_prefix": datapool_config["prompt_prefix"], "split": split}
logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}")
dp_split = CNNDailyMail.prepare(**kwargs)
dp_split = CNNDailyMail.prepare(split=kwargs["split"], prompt_prefix=kwargs["prompt_prefix"])
logger.info(f"finish loading split of dataset: {datapool_config['id']} -- {kwargs['split']}")
return dp_split

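
The builders now read explicit keys (`rouge_type`, `prompt_prefix`, `id`) instead of unpacking a free-form `args` dict. An illustrative sketch of the corresponding config sections follows; the values are placeholders, and the actual settings live in the t5_ppo_config referenced in the commits.

```python
# Illustrative shape of the config sections consumed by build_reward_fn / build_datapool.
reward_config = {
    "rouge_type": "rouge1",          # forwarded to RougeRewardFunction(rouge_type=...)
}

datapool_config = {
    "id": "cnn_daily_mail",          # used for logging in build_datapool
    "prompt_prefix": "Summarize: ",  # prepended to each article before tokenization
}
```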
2 changes: 1 addition & 1 deletion benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py
@@ -19,7 +19,7 @@
from nltk.tokenize import word_tokenize


class CNNDailyMail:
class CNNDailyMail(object):
def __init__(self, samples):
self._samples = samples
