diff --git a/benchmark/torch/RL4LMs/README.md b/benchmark/torch/RL4LMs/README.md new file mode 100644 index 000000000..722cc3e18 --- /dev/null +++ b/benchmark/torch/RL4LMs/README.md @@ -0,0 +1,37 @@ +## Reproduce Summarization-RLHF in RL4LMs using PARL + +> Paper: [Is Reinforcement Learning (Not) for Natural Language Processing: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization](https://arxiv.org/abs/2210.01241) + +### Background + +- Summarization task in NLP: Summarization is the task of producing a shorter version + of one document that preserves most of the input's meaning. +- RLHF: The abbreviation of Reinforcement Learning with Human Feedback, which uses human knowledge to train RL algorithms. + More information is available in the Hugging Face blog [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf) + +### Main contribution + +- Build new Summarization-RLHF framework using PARL +- Use PARL parallel training + +### How to use + +#### Install dependencies + +```bash +pip install -r requirements.txt +``` + +#### Start training +```bash +# start xparl +xparl start --port 8811 --cpu_num 10 + +# start training +python train.py +``` + +### Code Reference + +- Official code: [RL4LMs](https://github.com/allenai/RL4LMs) +- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) diff --git a/benchmark/torch/RL4LMs/instructor.py b/benchmark/torch/RL4LMs/instructor.py new file mode 100644 index 000000000..6f5cb84bb --- /dev/null +++ b/benchmark/torch/RL4LMs/instructor.py @@ -0,0 +1,281 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections import OrderedDict +import torch +from rl4lms_utils import Observation +from gym import spaces +from gym.spaces.dict import Dict as DictSpace +from gym.spaces.discrete import Discrete +import parl +from collections import deque +import numpy as np +from rl4lms_utils import build_datapool, build_tokenizer, build_reward_fn + + +def _flatten_obs(obs, space, n_instructor=None): + if n_instructor is not None: + return OrderedDict([(k, np.stack([o[k] for o in obs]).reshape((n_instructor, -1, len(obs[0][k])))) + for k in space.spaces.keys()]) + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + + +@parl.remote_class(wait=False) +class Instructor(object): + def __init__( + self, + reward_config=None, + tokenizer_config=None, + datapool_config=None, + max_episode_length=512, + max_prompt_length=None, + terminate_on_eos=False, + context_start_token=None, + prompt_truncation_side="left", + waiting_time_idx=0, + ): + """ + Instructor who gives reward + Args: + max_episode_length (int, optional): Max steps to the model Defaults to 512. + max_prompt_length (Optional[int], optional): maximum prompt length. Defaults to None. + terminate_on_eos (bool, optional): whether to terminate on EOS. Defaults to False. 
+ context_start_token (bool, optional): start token for the context (For Encoder-Decoder models! ) + prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") + """ + time.sleep( + waiting_time_idx * 90) # too many Instructors may cause problems if they load datasets at the same time + tokenizer = build_tokenizer(tokenizer_config) + samples = build_datapool(datapool_config, remote_train=True)["train"] + reward_function = build_reward_fn(reward_config) + self.tokenizer = tokenizer + self.reward_function = reward_function + self.max_steps = max_episode_length + self._max_text_length = (max_prompt_length if max_prompt_length else tokenizer.model_max_length) + self._terminate_on_eos = terminate_on_eos + self._context_start_token = context_start_token + self._prompt_truncation_side = prompt_truncation_side + + # set the observation and action space here + self._vocab_size = tokenizer.vocab_size + self.observation_space = DictSpace({ + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(self._max_text_length, )), + "prompt_or_input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self._max_text_length, )), + "context_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(self.max_steps, )), + "context_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self.max_steps, )), + "input_encoded_pt": + spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps, ), + ), + "input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(self._max_text_length + self.max_steps, )), + }) + self.action_space = Discrete(n=self._vocab_size) + # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency + if 'mt5' in self.tokenizer.name_or_path: + n = 250112 + self.action_space = Discrete(n=n) + elif 't5' in self.tokenizer.name_or_path: + n = 32128 + self.action_space = Discrete(n=n) + self.samples_for_replaying = deque() + for sample, weight in samples: + self.samples_for_replaying.append(sample) + + # check the tokenizer and add padding tokens + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" # TBD: configure this + self.tokenizer.truncation_side = "left" # TBD: configure this + + # init tracking variables + self.__current_sample = None + self.__current_obs = None + self.__time_step = None + + def get_new_obs_and_feedback_one_step(self, action): + self.__time_step += 1 + + # previous obs + previous_obs = self.__current_obs + + # just update the context tensor and gets the new observation + self.__current_obs = self.__current_obs.update(action, self.tokenizer) + + # decide if the episode is finished or not + done = (action == self.tokenizer.eos_token_id + and self._terminate_on_eos) or (self.__time_step == self.max_steps) + + # compute reward + reward = self.reward_function( + previous_obs, + action, + self.__current_obs, + done, + self.__current_obs.meta_info, + ) + + # populate additional info + info = { + "output": self.__current_obs.context_text, + "action_history": self.__current_obs.action_history, + "reference_text": self.__current_obs.target_or_reference_texts, + "prompt_text": self.__current_obs.prompt_or_input_text, + "prev_output": previous_obs.context_text, + "meta_info": previous_obs.meta_info, + } + + if done: + # save final observation where user can get it, then reset + 
info["terminal_observation"] = self.__current_obs.to_dict() + observation = self.ask() + return (observation, reward, done, info) + else: + return (self.__current_obs.to_dict(), reward, done, info) + + def get_new_obs_and_feedback_sentence(self, sentence): + res = [] + for token in sentence: + one_step_res = self.get_new_obs_and_feedback_one_step(token) + res.append(one_step_res) + return res + + def ask(self, sample=None): + """ + Reset the instructor and starts a new episode + """ + # gets a new sample if not provided + if sample is None: + sample = np.random.choice(a=self.samples_for_replaying, size=min(len(self.samples_for_replaying), 1))[0] + self.__current_sample = sample + + # init the observation + self.__current_obs = Observation.init_from_sample( + sample, + self.tokenizer, + self._max_text_length, + self.max_steps, + self._prompt_truncation_side, + self._context_start_token, + sample.meta_data, + ) + + # start the time step counter + self.__time_step = 0 + + dict_observation = self.__current_obs.to_dict() + return dict_observation + + def get_obs_and_action_space(self): + return (self.observation_space, self.action_space) + + +class InstructorGroup(object): + def __init__( + self, + instructor_config=None, + tokenizer=None, + datapool_config=None, + tokenizer_config=None, + ): + self.n_instructors = instructor_config["n_instructors"] + # remote instructors need to use config to initialize due to serialization problem + instructor_kwargs = { + "reward_config": instructor_config["reward_fn"], + "tokenizer_config": tokenizer_config, + "datapool_config": datapool_config, + "max_prompt_length": instructor_config["max_prompt_length"], + "max_episode_length": instructor_config["max_episode_length"], + "terminate_on_eos": instructor_config["terminate_on_eos"], + "prompt_truncation_side": instructor_config["prompt_truncation_side"], + "context_start_token": instructor_config["context_start_token"] + } + self.tokenizer = tokenizer + self._remote_instructors = self._create_instructors(instructor_kwargs, instructor_config["parl_master_address"]) + + # due to serialization problem, build obs space and action space here + self._vocab_size = tokenizer.vocab_size + self.observation_space = DictSpace({ + # while creating rollout buffers, observations are concatenated for each key + "prompt_or_input_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_prompt_length"], )), + "prompt_or_input_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_prompt_length"], )), + "context_encoded_pt": + spaces.Box(low=0, high=self._vocab_size, shape=(instructor_kwargs["max_episode_length"], )), + "context_attention_mask_pt": + spaces.Box(low=0, high=1, shape=(instructor_kwargs["max_episode_length"], )), + "input_encoded_pt": + spaces.Box( + low=0, + high=self._vocab_size, + shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], ), + ), + "input_attention_mask_pt": + spaces.Box( + low=0, + high=1, + shape=(instructor_kwargs["max_prompt_length"] + instructor_kwargs["max_episode_length"], )), + }) + self.action_space = Discrete(n=self._vocab_size) + + def ask(self): + future_object_ids = [remote_instructor.ask() for remote_instructor in self._remote_instructors] + sample_questions = [future_object.get() for future_object in future_object_ids] + # sample_questions = future_object_ids + return _flatten_obs(sample_questions, self.observation_space) + + def feedback_sentense(self, gen_output): + sentence_new_obs, 
sentence_rewards, sentence_dones, sentence_infos = \ + self._instructors_feedback_sentence(gen_output.step_wise_actions) + + return sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos + + def _instructors_feedback_sentence(self, all_sentences): + all_sentences = torch.stack(all_sentences).cpu().numpy().transpose(1, 0) + future_object_ids = [ + self._remote_instructors[i].get_new_obs_and_feedback_sentence(all_sentences[i]) + for i in range(self.n_instructors) + ] + + feedback_res = np.stack([future_object.get() for future_object in future_object_ids]) + + obs, rews, dones, infos = zip(*feedback_res.reshape(-1, 4)) + return _flatten_obs(obs, self.observation_space, self.n_instructors), \ + np.stack(rews).reshape(self.n_instructors, -1), np.stack(dones).reshape(self.n_instructors, -1),\ + np.stack(infos).reshape(self.n_instructors, -1) + + def _create_instructors(self, instructor_kwargs, parl_port=None): + parl.connect(parl_port, distributed_files=["./rl4lms_utils/*.py", "./*.py"]) + return [ + Instructor( + reward_config=instructor_kwargs["reward_config"], + tokenizer_config=instructor_kwargs["tokenizer_config"], + datapool_config=instructor_kwargs["datapool_config"], + max_episode_length=instructor_kwargs["max_episode_length"], + max_prompt_length=instructor_kwargs["max_prompt_length"], + terminate_on_eos=instructor_kwargs["terminate_on_eos"], + context_start_token=instructor_kwargs["context_start_token"], + prompt_truncation_side=instructor_kwargs["prompt_truncation_side"], + waiting_time_idx=idx, + ) for idx in range(self.n_instructors) + ] diff --git a/benchmark/torch/RL4LMs/requirements.txt b/benchmark/torch/RL4LMs/requirements.txt new file mode 100644 index 000000000..95ebbe5b0 --- /dev/null +++ b/benchmark/torch/RL4LMs/requirements.txt @@ -0,0 +1,11 @@ +parl>=2.1.1 +datasets==2.10.1 +torch==1.11.0 +torchvision==0.12.0 +transformers==4.18.0 +charset-normalizer==3.0.1 +gym==0.21.0 +cchardet==2.1.7 +nltk==3.7 +gem-metrics @ git+https://github.com/GEM-benchmark/GEM-metrics.git@431a8174bd6b3637e8d6118bfad2983e39e99733 +bert-score==0.3.11 diff --git a/benchmark/torch/RL4LMs/rl4lms_agent.py b/benchmark/torch/RL4LMs/rl4lms_agent.py new file mode 100644 index 000000000..91aa72b6d --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_agent.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import numpy as np +from gym import spaces +import torch +from parl.utils import logger + + +def explained_variance(y_pred, y_true): + """ + Computes fraction of variance that ypred explains about y. 
+ Returns 1 - Var[y-ypred] / Var[y] + + interpretation: + ev=0 => might as well have predicted zero + ev=1 => perfect prediction + ev<0 => worse than just predicting zero + """ + assert y_true.ndim == 1 and y_pred.ndim == 1 + var_y = np.var(y_true) + return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + +class RL4LMsAgent(parl.Agent): + def __init__( + self, + algorithm, + n_epochs, + batch_size=64, + norm_reward=False, + ): + super(RL4LMsAgent, self).__init__(algorithm) + self.dataset = None + self.n_epochs = n_epochs + self.batch_size = batch_size + self._norm_reward = norm_reward + self._n_updates = 0 + self.device = self.alg.model.device + + def learn(self, rollout_buffer): + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + approx_kl_divs = [] + + loss = torch.tensor(0.0) + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + continue_training = True + + for batch_ix, rollout_data in enumerate(list(rollout_buffer.get(self.batch_size))): + batch_action = rollout_data.actions + if isinstance(self.alg.model.action_space, spaces.Discrete): + # Convert discrete action from float to long + batch_action = rollout_data.actions.long().flatten() + batch_obs = rollout_data.observations + batch_adv = rollout_data.advantages + batch_logprob = rollout_data.old_log_prob + batch_return = rollout_data.returns + batch_value = rollout_data.old_values + + alg_learn_info = self.alg.learn( + batch_obs=batch_obs, + batch_action=batch_action, + batch_value=batch_value, + batch_return=batch_return, + batch_logprob=batch_logprob, + batch_adv=batch_adv) + + continue_training = alg_learn_info["continue_training"] + + entropy_losses.append(alg_learn_info["entropy_losses"]) + pg_losses.append(alg_learn_info["pg_losses"]) + value_losses.append(alg_learn_info["value_losses"]) + clip_fractions.append(alg_learn_info["clip_fractions"]) + approx_kl_divs.append(alg_learn_info["approx_kl_divs"]) + if not continue_training: + break + + self._n_updates += 1 # fix the calculation of self._n_updates + if not continue_training: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_divs[-1]:.2f}") + break + + # RL4LMs' method may lead to inaccurate calculation of self._n_updates when continue_training is false + # self._n_updates += self.n_epochs + explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten()) + + # Logs + train_info = { + "train/entropy_loss": np.mean(entropy_losses), + "train/policy_gradient_loss": np.mean(pg_losses), + "train/value_loss": np.mean(value_losses), + "train/approx_kl": np.mean(approx_kl_divs), + "train/clip_fraction": np.mean(clip_fractions), + "train/loss": loss.item(), + "train/explained_variance": explained_var + } + + if hasattr(self.alg.model, "log_std"): + # self.logger.record( + # "train/std", torch.exp(self.policy.log_std).mean().item()) + train_info["train/std"] = torch.exp(self.alg.model.log_std).mean().item() + + # self.logger.record("train/n_updates", + # self._n_updates, exclude="tensorboard") + # self.logger.record("train/clip_range", clip_range) + train_info["train/n_updates"] = self._n_updates + train_info["train/clip_param"] = self.alg.clip_param + + logger.info(train_info) + + ppo_train_info = { + "ppo/entropy_loss": np.mean(entropy_losses).item(), + "ppo/policy_gradient_loss": np.mean(pg_losses).item(), + "ppo/value_loss": np.mean(value_losses).item(), + "ppo/approx_kl": np.mean(approx_kl_divs).item(), + } + + logger.info(ppo_train_info) + + def 
prepare_obs_input(self, obs): + return {key: torch.as_tensor(_obs).to(self.device) for (key, _obs) in obs.items()} + + def value(self, obs): + return self.alg.value(obs) + + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization task) for + # collecting data and testing, so here use policy() instead of sample() and only need to return info + # like log_prob and gen_kwargs without action + def policy(self, obs, actions): + return self.alg.policy( + obs=obs, + actions=actions, + ) + + def ref_policy(self, obs, action): + return self.alg.ref_policy(obs, action) + + def predict( + self, + tokenizer, + dict_obs_tensor=None, + texts=None, + max_prompt_length=None, + gen_kwargs=None, + ): + obs_tensor = self.prepare_obs_input(dict_obs_tensor) + generation_inputs = self.alg.model.build_inputs(obs_tensor) + input_ids = generation_inputs.inputs + attention_mask = generation_inputs.attention_masks + return self.alg.predict( + input_ids=input_ids, + attention_mask=attention_mask, + tokenizer=tokenizer, + texts=texts, + max_prompt_length=max_prompt_length, + gen_kwargs=gen_kwargs) + + def eval_mode(self): + self.alg.eval_mode() diff --git a/benchmark/torch/RL4LMs/rl4lms_ppo.py b/benchmark/torch/RL4LMs/rl4lms_ppo.py new file mode 100644 index 000000000..fb683bf63 --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_ppo.py @@ -0,0 +1,180 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
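+
+# Editorial note: a rough sketch of the clipped PPO objective optimised in
+# RL4LMsPPO.learn below (notation is illustrative, not taken from the original code):
+#
+#   ratio_t = exp(log_prob_new(a_t | s_t) - log_prob_old(a_t | s_t))
+#   L_clip  = -mean_t[ min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t) ]
+#   loss    = L_clip + value_loss_coef * L_value - entropy_coef * entropy
+#
+# with eps = clip_param. The update is stopped early once the approximate KL
+# divergence between the old and new policies exceeds 1.5 * target_kl.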
+ +import parl +import torch +from parl.utils.utils import check_model_method + + +class RL4LMsPPO(parl.Algorithm): + def __init__( + self, + model, + clip_param=0.2, + value_loss_coef=0.5, + entropy_coef=0.0, + initial_lr=3e-4, + max_grad_norm=0.5, + use_clipped_value_loss=False, + norm_adv=True, + target_kl=None, + seed=None, + ): + # check model method + check_model_method(model, 'value', self.__class__.__name__) + check_model_method(model, 'policy', self.__class__.__name__) + + assert isinstance(clip_param, float) + assert isinstance(value_loss_coef, float) + assert isinstance(entropy_coef, float) + assert isinstance(initial_lr, float) + assert isinstance(max_grad_norm, float) + assert isinstance(use_clipped_value_loss, bool) + assert isinstance(norm_adv, bool) + + super(RL4LMsPPO, self).__init__(model=model) + self.initial_lr = initial_lr + self.clip_param = clip_param + self.norm_adv = norm_adv + self.entropy_coef = entropy_coef + self.value_loss_coef = value_loss_coef + self.max_grad_norm = max_grad_norm + self.target_kl = target_kl + self.seed = seed + self.use_clipped_value_loss = use_clipped_value_loss + + for param_group in self.model.optimizer.param_groups: + param_group["lr"] = self.initial_lr + self.optimizer = self.model.optimizer + + def learn(self, batch_obs, batch_action, batch_value, batch_return, batch_logprob, batch_adv, lr=None): + # Do a complete pass on the rollout batch + continue_training = True + learn_info = { + "entropy_losses": None, + "pg_losses": None, + "value_losses": None, + "clip_fractions": None, + "approx_kl_divs": None, + "loss": None, + "continue_training": None + } + + values, _ = self.model.value(batch_obs) + action_log_probs, entropy, _ = self.model.policy(batch_obs, batch_action) + values = values.flatten() + entropy_loss = torch.mean(entropy) + learn_info["entropy_losses"] = entropy_loss.item() + + # Normalize advantage + if self.norm_adv: + batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = torch.exp(action_log_probs - batch_logprob) + + # clipped surrogate loss + surr1 = ratio * batch_adv + surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_adv + + action_loss = -torch.min(surr1, surr2).mean() + + # Logging + learn_info["pg_losses"] = action_loss.item() + clip_fraction = torch.mean((torch.abs(ratio - 1) > self.clip_param).float()).item() + learn_info["clip_fractions"] = clip_fraction + + # clipping + # values_pred = values + if self.use_clipped_value_loss: + value_pred_clipped = batch_value + torch.clamp( + values - batch_value, + -self.clip_param, + self.clip_param, + ) + value_losses = (values - batch_return).pow(2) + value_losses_clipped = (value_pred_clipped - batch_return).pow(2) + value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean() + else: + value_loss = 0.5 * (batch_return - values).pow(2).mean() + + # Value loss using the TD(gae_lambda) target + # value_loss = F.mse_loss(batch_return, values_pred) + learn_info["value_losses"] = value_loss.item() + + loss = value_loss * self.value_loss_coef + action_loss - self.entropy_coef * entropy_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + log_ratio = action_log_probs 
- batch_logprob + approx_kl_div = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + learn_info["approx_kl_divs"] = approx_kl_div + + learn_info["loss"] = loss + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + learn_info["continue_training"] = continue_training + return learn_info + + if lr: + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + + # Optimization step + self.optimizer.zero_grad() + loss.backward() + # Clip grad norm + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optimizer.step() + learn_info["continue_training"] = continue_training + return learn_info + + def value(self, obs): + return self.model.value(obs) + + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here policy() only needs to return info + # like log_prob and gen_kwargs without action + def policy(self, obs, actions): + return self.model.policy( + obs=obs, + actions=actions, + ) + + def ref_policy(self, obs, action): + return self.model.ref_policy(obs, action) + + def predict( + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, + ): + return self.model.predict( + input_ids=input_ids, + attention_mask=attention_mask, + tokenizer=tokenizer, + texts=texts, + max_prompt_length=max_prompt_length, + gen_kwargs=gen_kwargs) + + def eval_mode(self): + self.model.eval() diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py new file mode 100644 index 000000000..8e721573b --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data_wrapper import RefPolicyOutput, GenerationInputs, GenerationOutputs,\ + PolicyType, Sample, Observation, TransitionInfo + +from .huggingface_generation_util import override_generation_routines + +from .buffer import DictRolloutBuffer + +from .kl_controller import KLController + +from .examiner import Examiner + +from .data_pool import CNNDailyMail + +from .reward_util import RougeRewardFunction + +from .component_build_util import build_tokenizer, build_metrics, build_reward_fn,\ + build_datapool + +from .rollout_util import RolloutUtil diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py new file mode 100644 index 000000000..f66d7ad7e --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/buffer.py @@ -0,0 +1,300 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from gym import spaces +from .data_wrapper import DictRolloutBufferSamples + +try: + # Check memory used by replay buffer when possible + import psutil +except ImportError: + psutil = None + + +def get_obs_shape(observation_space, ): + """ + Get the shape of the observation (useful for the buffers). + + :param observation_space: + :return: + """ + if isinstance(observation_space, spaces.Box): + return observation_space.shape + elif isinstance(observation_space, spaces.Discrete): + # Observation is an int + return (1, ) + elif isinstance(observation_space, spaces.MultiDiscrete): + # Number of discrete features + return (int(len(observation_space.nvec)), ) + elif isinstance(observation_space, spaces.MultiBinary): + # Number of binary features + return (int(observation_space.n), ) + elif isinstance(observation_space, spaces.Dict): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} + + else: + raise NotImplementedError(f"{observation_space} observation space is not supported") + + +class DictRolloutBuffer(object): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + """ + + def __init__( + self, + buffer_size, + observation_space, + action_space, + device="cpu", + gae_lambda=0.95, + gamma=0.99, + ): + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) + + self.action_dim = 1 + self.pos = 0 + self.full = False + self.device = device + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self): + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, 1) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, 1, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.values = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, 1), dtype=np.float32) + self.generator_ready = False + + self.pos = 0 + self.full = False + + def add( + self, + obs, + action, + reward, + episode_start, + value, + log_prob, + ): + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. 
+ """ + + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple instructors with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((1, ) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def add_transitions(self, episode_wise_transitions, rollout_info): + advantages_computed = False + for ep_ix, transitions in enumerate(episode_wise_transitions): + ep_length = len(transitions) + total_reward = 0.0 + total_kl_reward = 0.0 + for transition_ix, transition in enumerate(transitions): + total_reward += transition.task_reward + total_kl_reward += transition.kl_reward + rollout_info["rollout_info/kl_div_mean"].append(transition.kl_div) + rollout_info["rollout_info/log_prob"].append(transition.log_prob) + rollout_info["rollout_info/ref_log_prob"].append(transition.ref_log_prob) + rollout_info["rollout_info/values"].append(transition.value.numpy()) + + # add to buffer + if not self.full: + self.add( + transition.observation, + transition.action, + transition.total_reward, + transition.episode_start, + transition.value, + transition.log_prob, + ) + + # if the buffer is full, compute advantages + if self.full and not advantages_computed: + # we fetch the last value for the last time step + # values come from the next transitions's values + next_values = (transitions[transition_ix + 1].value if + (transition_ix + 1) < ep_length else torch.tensor([0.0])) + + self.compute_returns_and_advantage(last_values=next_values, dones=transition.done) + advantages_computed = True + + rollout_info["rollout_info/ep_rew"].append(total_reward) + rollout_info["rollout_info/ep_lens"].append(ep_length) + rollout_info["rollout_info/ep_kl_rew"].append(total_kl_reward) + + def compute_returns_and_advantage(self, last_values, dones): + """ + Post-processing step: compute the lambda-return (TD(lambda) estimate) + and GAE(lambda) advantage. + + Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. + + The TD(lambda) estimator has also two special cases: + - TD(1) is Monte-Carlo estimate (sum of discounted rewards) + - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) + + For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. + + :param last_values: state value estimation for the last step (one for each instructor) + :param dones: if the last step was a terminal step (one bool for each instructor). 
+ """ + # Convert to numpy + last_values = last_values.clone().cpu().numpy().flatten() + + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_values = last_values + else: + next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" + # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA + self.returns = self.advantages + self.values + + def swap_and_flatten(self, arr): + """ + Swap and then flatten axes 0 (buffer_size) and 1 (n_instructors) + to convert shape from [n_steps, n_instructors, ...] (when ... is the shape of the features) + to [n_steps_per_episode * n_instructors, ...] (which maintain the order) + + :param arr: + :return: + """ + shape = arr.shape + if len(shape) < 3: + shape = shape + (1, ) + return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) + + def get(self, batch_size): + assert self.full, "" + indices = np.random.permutation(self.buffer_size * 1) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * 1 + + start_idx = 0 + while start_idx < self.buffer_size * 1: + yield self._get_samples(indices[start_idx:start_idx + batch_size]) + start_idx += batch_size + + def to_torch(self, array, copy=True): + """ + Convert a numpy array to a PyTorch tensor. + Note: it copies the data by default + + :param array: + :param copy: Whether to copy or not the data + (may be useful to avoid changing things be reference) + :return: + """ + if copy: + return torch.tensor(array).to(self.device) + return torch.as_tensor(array).to(self.device) + + def _get_samples(self, batch_inds): + + return DictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) + for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py new file mode 100644 index 000000000..1dd9c81b2 --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/component_build_util.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoTokenizer +from parl.utils import logger +from .reward_util import RougeRewardFunction +from .metric_util import MetricRegistry +from .data_pool import CNNDailyMail + + +def build_tokenizer(tokenizer_config): + logger.info(f"loading tokenizer of [{tokenizer_config['model_name']}] model") + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["model_name"]) + except Exception: + logger.info(f"trying to use local_files to load tokenizer of [{tokenizer_config['model_name']}] model") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["model_name"], local_files_only=True) + if tokenizer.pad_token is None and tokenizer_config.get("pad_token_as_eos_token", True): + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = tokenizer_config.get("padding_side", "left") + tokenizer.truncation_side = tokenizer_config.get("truncation_side", "left") + return tokenizer + + +def build_reward_fn(reward_config): + logger.info(f"loading reward function: rouge") + reward_fn = RougeRewardFunction(rouge_type=reward_config["rouge_type"]) + return reward_fn + + +def build_metrics(metric_configs): + metrics = [ + MetricRegistry.get(metric_config["id"], metric_config.get("args", {})) for metric_config in metric_configs + ] + return metrics + + +def build_datapool(datapool_config, remote_train=False): + def _get_datapool_by_split(split): + kwargs = {"prompt_prefix": datapool_config["prompt_prefix"], "split": split} + logger.info(f"loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") + dp_split = CNNDailyMail.prepare(split=kwargs["split"], prompt_prefix=kwargs["prompt_prefix"]) + logger.info(f"finish loading split of dataset: {datapool_config['id']} -- {kwargs['split']}") + return dp_split + + train_datapool = _get_datapool_by_split("train") + + if remote_train: + samples_by_split = { + "train": [(sample, weight) for sample, weight in train_datapool], + } + return samples_by_split + + val_datapool = _get_datapool_by_split("val") + test_datapool = _get_datapool_by_split("test") + + samples_by_split = { + "train": [(sample, weight) for sample, weight in train_datapool], + "val": [sample for sample, _ in val_datapool], + "test": [sample for sample, _ in test_datapool] + } + return samples_by_split diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py new file mode 100644 index 000000000..9ae1e0f9b --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_pool.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import load_dataset +from .data_wrapper import Sample +import random +from tqdm import tqdm +from nltk.tokenize import word_tokenize + + +class CNNDailyMail(object): + def __init__(self, samples): + self._samples = samples + + def __len__(self): + return len(self._samples) + + def __getitem__(self, ix): + if ix >= len(self): + raise StopIteration + sample = self._samples[ix] + return sample, 1.0 + + @classmethod + def prepare(cls, split, prompt_suffix="", prompt_prefix="", truncate_article=None, max_size=None): + split2name = {"train": "train", "val": "validation", "test": "test"} + dataset = load_dataset("cnn_dailymail", "3.0.0") + dataset_split = split2name[split] + samples = [] + for ix, item in tqdm( + enumerate(dataset[dataset_split]), desc="Tokenizing dataset", total=len(dataset[dataset_split])): + + if truncate_article is not None: + tokens = word_tokenize(item["article"]) + tokens = tokens[:truncate_article] + item["article"] = " ".join(tokens) + + sample = Sample( + id=f"{split}_{ix}", + prompt_or_input_text=prompt_prefix + item["article"] + prompt_suffix, + references=[item["highlights"]]) + samples.append(sample) + + if max_size is not None and ix == (max_size - 1): + break + + pool_instance = cls(samples) + return pool_instance + + def sample(self): + random_sample = random.choice(self._samples) + return random_sample + + def split(self, split_ratios): + start_ix = 0 + pools = [] + for ratio in split_ratios: + count = int(len(self) * ratio) + end_ix = start_ix + count + pools.append(type(self)(self._samples[start_ix:end_ix])) + start_ix = end_ix + return pools diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py new file mode 100644 index 000000000..bd92d56fb --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/data_wrapper.py @@ -0,0 +1,255 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
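+
+# Editorial note: the central structure in this module is `Observation`, which keeps
+# the tokenized prompt, the generated context so far, and their left-padded
+# concatenation. A rough usage sketch (argument values are illustrative only):
+#
+#   obs = Observation.init_from_sample(sample, tokenizer,
+#                                      max_input_length=512,
+#                                      max_context_length=100,
+#                                      prompt_truncation_side="left")
+#   obs = obs.update(action=next_token_id, tokenizer=tokenizer)  # append one generated token
+#   dict_obs = obs.to_dict()  # flat numpy dict consumed by the rollout buffer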
+ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List +from transformers import AutoTokenizer +from copy import deepcopy +from typing import NamedTuple +import torch +import numpy as np + +from typing import Any, Union + +TensorDict = Dict[Union[str, int], torch.Tensor] + + +@dataclass +class TransitionInfo(object): + observation: TensorDict + action: np.ndarray + task_reward: np.ndarray + total_reward: np.ndarray + kl_div: np.ndarray + episode_start: np.ndarray + value: torch.Tensor + log_prob: torch.Tensor + done: np.ndarray + ref_log_prob: torch.Tensor + kl_reward: np.ndarray + info: Dict[str, Any] + + +class DictRolloutBufferSamples(NamedTuple): + observations: TensorDict + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + + +@dataclass(init=True) +class Sample(object): + id: str + prompt_or_input_text: str + references: List[str] + meta_data: Dict[str, Any] = None + + +class PolicyType(Enum): + CAUSAL = 0 + SEQ2SEQ = 1 + + +@dataclass +class RefPolicyOutput(object): + """ + Dataclass for the output of the method policy.get_ref_log_probs() + """ + + # ref log_probs for corresponding observation and chosen action + log_probs: torch.tensor + # cached policy activations for sequential forward passes + past_model_kwargs: torch.tensor + + +@dataclass +class GenerationInputs(object): + # prompt inputs + inputs: torch.tensor + # prompt attention masks + attention_masks: torch.tensor + + +@dataclass +class GenerationOutputs(object): + # log probs at each time step + step_wise_logprobs: List[List[torch.tensor]] + # actions at each time step + step_wise_actions: List[torch.tensor] + # generated tokens + gen_tokens: List[List[int]] + # generated texts + gen_texts: List[str] + # action masks + action_masks: List[torch.tensor] = None + + +@dataclass +class Observation(object): + # encoded input + prompt_or_input_encoded_pt: torch.tensor + # attention mask for the input + prompt_or_input_attention_mask_pt: torch.tensor + # input text + prompt_or_input_text: str + # encoded context + context_encoded_pt: torch.tensor + # attention mask for the context + context_attention_mask_pt: torch.tensor + # context text + context_text: str + # reference texts + target_or_reference_texts: List[str] + + # concatenated input + input_encoded_pt: torch.tensor + input_attention_mask_pt: torch.tensor + + # list of actions + action_history: List[str] + + # other meta info + meta_info: Dict[str, Any] + + def to_dict(self): + """ + For stable baselines (only return tensor items) + """ + dict_obs = { + "prompt_or_input_encoded_pt": self.prompt_or_input_encoded_pt.numpy().flatten(), + "prompt_or_input_attention_mask_pt": self.prompt_or_input_attention_mask_pt.numpy().flatten(), + "context_encoded_pt": self.context_encoded_pt.numpy().flatten(), + "context_attention_mask_pt": self.context_attention_mask_pt.numpy().flatten(), + "input_encoded_pt": self.input_encoded_pt.numpy().flatten(), + "input_attention_mask_pt": self.input_attention_mask_pt.numpy().flatten() + } + return dict_obs + + @staticmethod + def _concat(prompt: torch.tensor, prompt_mask: torch.tensor, context: torch.tensor, context_mask: torch.tensor, + pad_token: int): + + prompt_ = prompt[:, prompt_mask.flatten().bool().tolist()] + context_ = context[:, context_mask.flatten().bool().tolist()] + actual_size = prompt_.shape[1] + context_.shape[1] + + full_size = prompt.shape[1] + context.shape[1] + concatenated = torch.full((full_size, ), 
fill_value=pad_token).reshape(1, -1) + concatenated_mask = torch.zeros((1, full_size)).int() + + concatenated[:, full_size - actual_size:] = torch.cat((prompt_, context_), dim=1) + concatenated_mask[:, full_size - actual_size:] = 1 + return concatenated, concatenated_mask + + def update(self, action: int, tokenizer: AutoTokenizer): + """ + Updates the observation using the given action + """ + + # update the action history + current_action_history = deepcopy(self.action_history) + current_action_history.append(tokenizer._convert_id_to_token(action)) + + # get the current context + current_context = deepcopy(self.context_encoded_pt) + current_context_attention_mask = deepcopy(self.context_attention_mask_pt) + + # just shift the context (also the attention mask) to left by 1 + current_context[:, 0:-1] = current_context[:, 1:].clone() + current_context_attention_mask[:, 0:-1] = current_context_attention_mask[:, 1:].clone() + + # add the action always at the end (assumes left padding) + current_context[:, -1] = action + current_context_attention_mask[:, -1] = 1 + + # decode the context + context_text = tokenizer.decode(current_context.flatten(), skip_special_tokens=True) + + # concatenate and still keep the left padding + input_encoded_pt, input_attention_mask_pt = Observation._concat( + self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, current_context, + current_context_attention_mask, tokenizer.pad_token_id) + + # and create a new observation + obs = Observation(self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt, + self.prompt_or_input_text, current_context, current_context_attention_mask, context_text, + self.target_or_reference_texts, input_encoded_pt, input_attention_mask_pt, + current_action_history, self.meta_info) + + return obs + + @classmethod + def init_from_sample(cls, + sample: Sample, + tokenizer: AutoTokenizer, + max_input_length: int, + max_context_length: int, + prompt_truncation_side: str, + context_start_token: int = None, + meta_info: Dict[str, Any] = None): + # encode the prompt text + # override truncation side for prompt + prev_truncation_side = tokenizer.truncation_side + tokenizer.truncation_side = prompt_truncation_side + prompt_outputs = tokenizer( + sample.prompt_or_input_text, + padding="max_length", + max_length=max_input_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True) + tokenizer.truncation_side = prev_truncation_side + + # for seq2seq models, context should be initialized to start token if provided + if context_start_token is not None: + context_outputs = tokenizer( + "", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) + context_outputs.input_ids = torch.ones(1, max_context_length, dtype=torch.int32) * tokenizer.pad_token_id + context_outputs.input_ids[:, -1] = context_start_token + context_outputs.attention_mask = torch.zeros(1, max_context_length, dtype=torch.int32) + context_outputs.attention_mask[:, -1] = 1 + else: + context_outputs = tokenizer( + "", + padding="max_length", + max_length=max_context_length, + return_tensors="pt", + return_attention_mask=True) + + # concatenate + input_encoded_pt, input_attention_mask_pt = Observation._concat( + prompt_outputs.input_ids, prompt_outputs.attention_mask, context_outputs.input_ids, + context_outputs.attention_mask, tokenizer.pad_token_id) + + obs = Observation( + prompt_or_input_encoded_pt=prompt_outputs.input_ids, + 
prompt_or_input_attention_mask_pt=prompt_outputs.attention_mask, + prompt_or_input_text=sample.prompt_or_input_text, + context_encoded_pt=context_outputs.input_ids, + context_attention_mask_pt=context_outputs.attention_mask, + input_encoded_pt=input_encoded_pt, + input_attention_mask_pt=input_attention_mask_pt, + context_text="", + target_or_reference_texts=sample.references, + action_history=[], + meta_info=meta_info) + + return obs diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py new file mode 100644 index 000000000..698a03c53 --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/examiner.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tqdm import tqdm +from parl.utils import logger + + +# class for results evaluation +class Examiner(object): + def __init__(self, tokenizer, eval_batch_size, metrics, eval_gen_kwargs, samples_by_split, max_prompt_length): + self._tokenizer = tokenizer + self._batch_size = eval_batch_size + self._metrics = metrics + self._gen_kwargs = eval_gen_kwargs + self._samples_by_split = samples_by_split + self._max_prompt_length = max_prompt_length + + def evaluate(self, policy, sample_name_list, epoch): + for split_name in sample_name_list: + self._evaluate_on_samples(policy=policy, epoch=epoch, split_name=split_name) + + def _evaluate_on_samples( + self, + policy, + epoch, + split_name, + dt_control_token="", + ): + samples = self._samples_by_split[split_name] + # generate text by batch + all_generated_texts = [] + all_ref_texts = [] + all_prompt_texts = [] + all_meta_infos = [] + + n_samples = len(samples) + for batch in tqdm(list(self._get_batch(samples, self._batch_size)), desc="Evaluating"): + batch_generated_texts = self._generate_text(policy, self._tokenizer, batch, self._max_prompt_length, + dt_control_token) + batch_ref_texts = [sample.references for sample in batch] + batch_prompt_texts = [sample.prompt_or_input_text for sample in batch] + batch_meta_infos = [sample.meta_data for sample in batch] + all_generated_texts.extend(batch_generated_texts) + all_ref_texts.extend(batch_ref_texts) + all_prompt_texts.extend(batch_prompt_texts) + all_meta_infos.extend(batch_meta_infos) + + # compute metrics + corpus_level_metrics = {} + sample_scores_by_metric = {} + if self._metrics is not None: + for metric in self._metrics: + metric_dict = metric.compute( + all_prompt_texts, + all_generated_texts, + all_ref_texts, + all_meta_infos, + policy.get_language_model(), + split_name, + ) + + for metric_key, (sample_scores, corpus_score) in metric_dict.items(): + if sample_scores is None: + sample_scores = ["n/a"] * n_samples + corpus_level_metrics[metric_key] = corpus_score + sample_scores_by_metric[metric_key] = sample_scores + + # aggregate sample metric scores + sample_predictions_dict = [] + for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate( + zip(samples, all_prompt_texts, all_generated_texts, 
all_ref_texts)): + sample_prediction = { + "split_name": + split_name, + "sample_id": + sample.id, + "prompt_text": + prompt_text, + "generated_text": + generated_text, + "ref_text": + "".join([ + f"" + ref_text + f"" for ref_ix, ref_text in enumerate(ref_texts) + ]), + } + for metric_key, sample_scores in sample_scores_by_metric.items(): + sample_prediction[metric_key] = sample_scores[ix] + sample_predictions_dict.append(sample_prediction) + + metrics_dict_ = {"epoch": epoch, "metrics": corpus_level_metrics} + + # logger + logger.info(f"{split_name} metrics: {metrics_dict_}") + + def _get_batch(self, samples, batch_size): + current_ix = 0 + n_samples = len(samples) + while current_ix < n_samples: + current_batch = samples[current_ix:current_ix + batch_size] + yield current_batch + current_ix += batch_size + + def _generate_text( + self, + policy, + tokenizer, + samples, + max_prompt_length, + dt_control_token, + ): + prompt_texts = [dt_control_token + sample.prompt_or_input_text for sample in samples] + generated_texts = policy.predict( + tokenizer, prompt_texts, max_prompt_length, gen_kwargs=self._gen_kwargs).gen_texts + return generated_texts diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py new file mode 100644 index 000000000..c72f8c82c --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/huggingface_generation_util.py @@ -0,0 +1,1183 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py + +from transformers.generation_utils import GenerationMixin +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch import nn + +from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint +from transformers.generation_logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) +from transformers.generation_stopping_criteria import ( + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) +from transformers.utils import ModelOutput, logging + +logger = logging.get_logger(__name__) + + +@dataclass +class SampleEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of + the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. 
+ scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape + `(batch_size*num_return_sequences, config.vocab_size)`). + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape + `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length, + sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +class GenerationMixinWithRawScores(object): + """ + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + + The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. 
+ - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], + if `constraints!=None` or `force_words_ids!=None`. + """ + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if (self.config.is_encoder_decoder and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError(f"`inputs`: {inputs}` were passed alongside " + f"{input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=...") + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. models with `input_ids` can also make use of `inputs_embeds` + if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. Only encoder-decoder models can have non `input_ids` input format + if not self.config.is_encoder_decoder and input_name != "input_ids": + raise ValueError(f"If {input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}.") + + # 5. if `inputs` is still None, try to create `input_ids` from BOS token + if inputs is None: + inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) + + return inputs, input_name, model_kwargs + + def _can_retrieve_inputs_from_name(self, inputs: Optional[torch.Tensor], name: str, + model_kwargs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved + from name + """ + can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( + inspect.signature(self.forward).parameters.keys()) + + if can_retrieve_inputs and inputs is not None: + raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") + + return can_retrieve_inputs + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. 
+ """ + return {"input_ids": input_ids} + + def _prepare_input_ids_for_generation(self, bos_token_id: Optional[int], + encoder_outputs: Optional[ModelOutput]) -> torch.LongTensor: + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.last_hidden_state.size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, + inputs: torch.Tensor, + pad_token_id: int, + eos_token_id: int, + ) -> torch.LongTensor: + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ((eos_token_id is not None) and + (pad_token_id != eos_token_id)) + # Check if input is input_ids and padded -> only then is attention_mask defined + if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: + return inputs.ne(pad_token_id).long() + else: + return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device) + + def _prepare_encoder_decoder_kwargs_for_generation(self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None) -> Dict[str, Any]: + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + # 3. 
make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) + + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + return model_kwargs.pop("decoder_input_ids") + else: + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + decoder_start_token_id = (decoder_start_token_id + if decoder_start_token_id is not None else self.config.decoder_start_token_id) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif (hasattr(self.config, "decoder") and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None): + return self.config.decoder.decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + elif (hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None): + return self.config.decoder.bos_token_id + raise ValueError("`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation.") + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = (torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to( + input_ids.device)) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + @staticmethod + def _update_model_kwargs_for_generation(outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False) -> Dict[str, Any]: + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update token_type_ids 
with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + + return model_kwargs + + def _reorder_cache(self, past, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" + ) + + def _get_logits_warper( + self, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + temperature: Optional[float] = None, + num_beams: Optional[int] = None, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + typical_p = typical_p if typical_p is not None else self.config.typical_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = LogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if typical_p is not None and typical_p < 1.0: + warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + return warpers + + def _get_logits_processor( + self, + repetition_penalty: float, + no_repeat_ngram_size: int, + encoder_no_repeat_ngram_size: int, + input_ids_seq_length: int, + encoder_input_ids: torch.LongTensor, + bad_words_ids: List[List[int]], + min_length: int, + max_length: int, + eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], + num_beams: int, + num_beam_groups: int, + diversity_penalty: float, + remove_invalid_values: bool, + exponential_decay_length_penalty: Tuple, + logits_processor: Optional[LogitsProcessorList], + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. 
+ """ + processors = LogitsProcessorList() + + # init warp parameters + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + no_repeat_ngram_size = (no_repeat_ngram_size + if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size) + encoder_no_repeat_ngram_size = (encoder_no_repeat_ngram_size if encoder_no_repeat_ngram_size is not None else + self.config.encoder_no_repeat_ngram_size) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + min_length = min_length if min_length is not None else self.config.min_length + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty + forced_bos_token_id = (forced_bos_token_id + if forced_bos_token_id is not None else self.config.forced_bos_token_id) + forced_eos_token_id = (forced_eos_token_id + if forced_eos_token_id is not None else self.config.forced_eos_token_id) + remove_invalid_values = (remove_invalid_values + if remove_invalid_values is not None else self.config.remove_invalid_values) + exponential_decay_length_penalty = (exponential_decay_length_penalty + if exponential_decay_length_penalty is not None else + self.config.exponential_decay_length_penalty) + # instantiate processors list + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if diversity_penalty is not None and diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups)) + if repetition_penalty is not None and repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) + if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if self.config.is_encoder_decoder: + processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) + else: + raise ValueError("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture") + if bad_words_ids is not None: + processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) + if min_length is not None and eos_token_id is not None and min_length > 0: + processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) + if prefix_allowed_tokens_fn is not None: + processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) + if forced_bos_token_id is not None: + processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + if remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) + if exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)) + processors = self._merge_criteria_processor_list(processors, logits_processor) + return processors + + def _get_stopping_criteria(self, max_length: Optional[int], max_time: Optional[float], + stopping_criteria: 
Optional[StoppingCriteriaList]) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if max_length is not None: + criteria.append(MaxLengthCriteria(max_length=max_length)) + if max_time is not None: + criteria.append(MaxTimeCriteria(max_time=max_time)) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _merge_criteria_processor_list( + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " + f"but it has already been created with the values {default}. {default} has been created by passing the " + "corresponding arguments to generate or by the model's config default values. " + f"If you just want to change the default values of {object_type} consider passing them as arguments " + f"to `generate` instead of using a custom {object_type}.") + default_list.extend(custom_list) + return default_list + + def compute_beam_search_raw_logits( + self, + sequences: torch.Tensor, + scores: Tuple[torch.Tensor], + beam_indices: torch.Tensor, + eos_token_id: int = None, + ): + """Compute raw logits for beam search""" + + if not self.config.is_encoder_decoder: + raise NotImplementedError("Beam Search raw logits code is implemented only for enoder-decoder only models") + + # since sequences can be shorter than scores (probably due to beam search finalization) + # we always have to generate raw_logits only for generated sequences + # cut off the start tokens from generated + sequences = sequences.clone() + sequences = sequences[:, 1:] + gen_steps = sequences.shape[1] + + # align scores and beam indices according to gen_steps + # scores(gen_steps x(batch_size * num_beams) x vocab_size) + scores = scores[:gen_steps] + scores = torch.stack(scores) + _, _, vocab_size = scores.shape + + beam_indices = torch.tensor(beam_indices).T.to(scores.device) + beam_indices = beam_indices[:gen_steps, :] + batch_size = beam_indices.shape[1] + + # gen_steps x batch_size x vocab_size + beam_indices = beam_indices.unsqueeze(-1).repeat(1, 1, vocab_size) + step_wise_logits = scores.gather(dim=1, index=beam_indices) + assert step_wise_logits.shape == torch.Size((gen_steps, batch_size, vocab_size)) + + # finally convert to tuples + step_wise_logits = [(step_wise_logits[t], None) for t in range(gen_steps)] + return step_wise_logits + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + 
no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), + constraints: Optional[List[Constraint]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, + **model_kwargs, + ): + r""" + + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. + - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if + `num_beams>1` and `do_sample=True`. + - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if + `num_beams>1` and `num_beam_groups>1`. + - *constrained beam-search decoding* by calling + [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or + `force_words_ids!=None`. + + + + Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as + defined in the model's config (`config.json`) which in turn defaults to the + [`~modeling_utils.PretrainedConfig`] of the model. + + + + Most of these parameters are explained in more detail in [this blog + post](https://huggingface.co/blog/how-to-generate). + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length of the sequence to be generated. + max_new_tokens (`int`, *optional*, defaults to None): + The maximum numbers of tokens to generate, ignore the current number of tokens. Use either + `max_new_tokens` or `max_length` but not both, they serve the same purpose. + min_length (`int`, *optional*, defaults to 10): + The minimum length of the sequence to be generated. 
+ do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + num_beams (`int`, *optional*, defaults to 1): + Number of beams for beam search. 1 means no beam search. + temperature (`float`, *optional*, defaults to 1.0): + The value used to module the next token probabilities. + top_k (`int`, *optional*, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher + are kept for generation. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + bos_token_id (`int`, *optional*): + The id of the *beginning-of-sequence* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size can only occur once. + encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): + If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the + `decoder_input_ids`. + bad_words_ids(`List[List[int]]`, *optional*): + List of token ids that are not allowed to be generated. In order to get the token ids of the words that + should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, + add_special_tokens=False).input_ids`. + force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): + List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple + list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, + this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), + where one can allow different forms of each word. + num_return_sequences(`int`, *optional*, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + max_time(`float`, *optional*, defaults to None): + The maximum amount of time you allow the computation to run for in seconds. generation will still + finish the current pass after allocated time has been passed. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens + that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape + as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) + decoder_start_token_id (`int`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. 
+ use_cache: (`bool`, *optional*, defaults to `True`): + Whether or not the model should use the past last key/values attentions (if applicable to the model) to + speed up decoding. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of + beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + diversity_penalty (`float`, *optional*, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is + enabled. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + model's config. If a stopping criteria is passed that is already created with the arguments or a + model's config an error is thrown. This feature is intended for advanced users. + constraints (`List[Constraint]`, *optional*): + Custom constraints that can be added to the generation to ensure that the output will contain the use + of certain tokens as defined by `Constraint` objects, in the most sensible way possible. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + forced_bos_token_id (`int`, *optional*): + The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful + for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be + the target language token. + forced_eos_token_id (`int`, *optional*): + The id of the token to force as the last generated token when `max_length` is reached. + remove_invalid_values (`bool`, *optional*): + Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to + crash. 
Note that using `remove_invalid_values` can slow down generation. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been + generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates + where penalty starts and `decay_factor` represents the factor of exponential decay + + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model + is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs + should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.SampleDecoderOnlyOutput`], + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation_utils.SampleEncoderDecoderOutput`], + + Examples: + + Greedy Decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # generate up to 30 tokens + >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] + ``` + + Multinomial Sampling: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> prompt = "Today I believe we can finally" + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> # sample up to 30 tokens + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] + ``` + + Beam-search decoding: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + + >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") + + >>> sentence = "Paris is one of the densest populated areas in Europe." + >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids + + >>> outputs = model.generate(input_ids) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] + ```""" + # 1. 
Set generation parameters if not already defined + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + num_beams = num_beams if num_beams is not None else self.config.num_beams + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_return_sequences = (num_return_sequences + if num_return_sequences is not None else self.config.num_return_sequences) + + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + if eos_token_id is None and hasattr(self.config, "decoder"): + eos_token_id = self.config.decoder.eos_token_id + + if pad_token_id is None and eos_token_id is not None: + # special case if pad_token_id is not defined + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + pad_token_id = eos_token_id + + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict_in_generate = (return_dict_in_generate + if return_dict_in_generate is not None else self.config.return_dict_in_generate) + + # 2. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + batch_size = inputs_tensor.shape[0] + + # 3. Define other model kwargs + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + model_kwargs["use_cache"] = use_cache + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, pad_token_id, eos_token_id) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, + model_input_name) + + # 4. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + input_ids_seq_length = input_ids.shape[-1] + + # 5. 
Prepare `max_length` depending on other stopping criteria + # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` + if max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: + # Both are set, this is odd, raise a warning + warnings.warn( + "Both `max_length` and `max_new_tokens` have been set " + f"but they serve the same purpose. `max_length` {max_length} " + f"will take priority over `max_new_tokens` {max_new_tokens}.", + UserWarning, + ) + # default to config if still None + max_length = max_length if max_length is not None else self.config.max_length + + if input_ids_seq_length >= max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " + "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." + ) + + # 6. determine generation mode + is_constraint_gen_mode = constraints is not None or force_words_ids is not None + is_greedy_gen_mode = ((num_beams == 1) and (num_beam_groups == 1) and do_sample is False + and not is_constraint_gen_mode) + is_sample_gen_mode = ((num_beams == 1) and (num_beam_groups == 1) and do_sample is True + and not is_constraint_gen_mode) + is_beam_gen_mode = ((num_beams > 1) and (num_beam_groups == 1) and do_sample is False + and not is_constraint_gen_mode) + is_beam_sample_gen_mode = ((num_beams > 1) and (num_beam_groups == 1) and do_sample is True + and not is_constraint_gen_mode) + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode + + if num_beam_groups > num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`.") + + # 7. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + bad_words_ids=bad_words_ids, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, + remove_invalid_values=remove_invalid_values, + exponential_decay_length_penalty=exponential_decay_length_penalty, + logits_processor=logits_processor, + ) + + # 8. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria) + + # 9. go into different generation modes + if is_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper( + top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams) + + # 11. 
expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + else: + raise NotImplementedError + + def sample( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... 
) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict_in_generate = (return_dict_in_generate + if return_dict_in_generate is not None else self.config.return_dict_in_generate) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = (model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states else None) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + cur_len = input_ids.shape[-1] + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits_raw = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits, model_inputs=model_inputs) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += ((next_token_logits_raw, next_token_scores), ) + if output_attentions: + decoder_attentions += ((outputs.decoder_attentions, ) if self.config.is_encoder_decoder else + (outputs.attentions, )) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions, ) + + if output_hidden_states: + decoder_hidden_states += ((outputs.decoder_hidden_states, ) if self.config.is_encoder_decoder else + (outputs.hidden_states, )) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + \ + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder) + cur_len = cur_len + 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + raise NotImplementedError + else: + return input_ids + + +def override_generation_routines(cls): + bases = list(cls.__bases__) + for base_ix in range(len(bases)): + if bases[base_ix] == GenerationMixin: + bases[base_ix] = GenerationMixinWithRawScores + + # recursively look up + if bases[base_ix] != object: + bases[base_ix] = override_generation_routines(bases[base_ix]) + + cls.__bases__ = tuple(bases) + return cls diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py new file mode 100644 index 000000000..81b3f4f22 --- /dev/null +++ 
b/benchmark/torch/RL4LMs/rl4lms_utils/kl_controller.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +#
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
+import torch + +
+class KLController(object): + def __init__(self, kl_coeff, target_kl=None): + self._kl_coeff = kl_coeff + self._target_kl = target_kl +
+ def step(self, kl_div: torch.Tensor): + """ + Adapt the KL coefficient based on the observed KL divergence. + """
+ if self._target_kl is not None: + diff_to_target = (kl_div - self._target_kl) / self._target_kl
+ e_t = torch.clip(diff_to_target, -0.2, 0.2).item() + self._kl_coeff = self._kl_coeff * (1 + 0.1 * e_t) +
+ @property + def kl_coeff(self): + return self._kl_coeff diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py new file mode 100644
index 000000000..fc155e051 --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/metric_util.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +#
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +#
+# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License.
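
The `KLController` in `kl_controller.py` above implements an adaptive KL penalty: when `target_kl` is set, each `step()` call moves the coefficient in proportion to how far the measured KL divergence is from the target, with the relative error clipped to ±0.2, so a single update changes the coefficient by at most 0.1 * 0.2 = 2%. Below is a minimal standalone sketch of that update rule; the starting coefficient, target, and measured KL values are made up purely for illustration.

```python
import torch

# Standalone sketch of the adaptive update performed by KLController.step();
# the starting coefficient, target and measured KL values are hypothetical.
kl_coeff, target_kl = 0.2, 0.1

for measured_kl in [0.30, 0.18, 0.12, 0.08]:  # e.g. per-rollout mean KL(policy || reference LM)
    rel_err = (measured_kl - target_kl) / target_kl            # relative distance from the target
    rel_err = torch.clip(torch.tensor(rel_err), -0.2, 0.2).item()
    kl_coeff = kl_coeff * (1 + 0.1 * rel_err)                  # at most a +/-2% change per update
    print(f"measured KL={measured_kl:.2f} -> kl_coeff={kl_coeff:.4f}")
```

`RolloutUtil` in `rollout_util.py` further below drives this controller by calling `self._kl_controller.step(...)` with the mean KL divergence of each rollout, so the KL penalty tracks the target instead of remaining a fixed hyperparameter.
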
+ +import torch +import numpy as np +from datasets import load_metric +from gem_metrics.msttr import MSTTR +from gem_metrics.ngrams import NGramStats +from gem_metrics.texts import Predictions +from parl.utils import logger + + +class MeteorMetric(object): + def __init__(self): + super().__init__() + self._metric = load_metric("meteor") + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, + ): + + score = self._metric.compute(predictions=generated_texts, references=reference_texts)["meteor"] + + metric_dict = {"lexical/meteor": (None, score)} + return metric_dict + + +class RougeMetric(object): + def __init__(self, use_single_ref=True): + super().__init__() + self._metric = load_metric("rouge") + self._use_single_ref = use_single_ref + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, + ): + if self._use_single_ref: + # TBD: this is required for CNN/DM dataset, without this we get low scores + # TBD: needs investigation + ref_texts = [ref[0] for ref in reference_texts] + else: + ref_texts = reference_texts + + metric_results = self._metric.compute(predictions=generated_texts, references=ref_texts, use_stemmer=True) + score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + metric_dict = {} + for rouge_type in score_keys: + rouge_score = metric_results[rouge_type].mid.fmeasure + metric_dict[f"lexical/rouge_{rouge_type}"] = (None, rouge_score) + return metric_dict + + +class BERTScoreMetric(object): + def __init__(self, language): + super().__init__() + self._metric = load_metric("bertscore") + self._language = language + # since models are loaded heavily on cuda:0, use the last one to avoid memory + self._last_gpu = f"cuda:{torch.cuda.device_count() - 1}" + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, + ): + with torch.no_grad(): + metric_results = self._metric.compute( + predictions=generated_texts, + references=reference_texts, + lang=self._language, + device=self._last_gpu, + ) + bert_scores = metric_results["f1"] + corpus_level_score = np.mean(bert_scores) + metric_dict = {"semantic/bert_score": (bert_scores, corpus_level_score)} + return metric_dict + + +class BLEUMetric(object): + def __init__(self): + super().__init__() + self._metric = load_metric("bleu") + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, + ): + + tokenized_predictions = [] + tokenized_reference_texts = [] + for prediction, refs in zip(generated_texts, reference_texts): + tokenized_prediction = prediction.split() + tokenized_refs = [ref.split() for ref in refs] + tokenized_predictions.append(tokenized_prediction) + tokenized_reference_texts.append(tokenized_refs) + + try: + metric_results = self._metric.compute( + predictions=tokenized_predictions, references=tokenized_reference_texts) + bleu_score = metric_results["bleu"] + metric_dict = {"lexical/bleu": (None, bleu_score)} + return metric_dict + except Exception as e: + return {"lexical/bleu": (None, "n/a")} + + +class DiversityMetrics(object): + def __init__(self, window_size=100): + self._msttr_metric = MSTTR(window_size=window_size) + self._n_gram_metric = NGramStats() + + def compute( + self, + prompt_texts, + generated_texts, + reference_texts, + meta_infos=None, + model=None, + split_name=None, + ): + + predictions = 
Predictions(data={"filename": "", "values": generated_texts}) + diversity_metrics = {} + msttr_metrics = self._msttr_metric.compute(None, predictions) + n_gram_metrics = self._n_gram_metric.compute(None, predictions) + + for key, value in msttr_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + for key, value in n_gram_metrics.items(): + diversity_metrics[f"diversity_metrics/{key}"] = (None, value) + + return diversity_metrics + + +class MetricRegistry(object): + _registry = { + "meteor": MeteorMetric, + "rouge": RougeMetric, + "bert_score": BERTScoreMetric, + "bleu": BLEUMetric, + "diversity": DiversityMetrics, + } + + @classmethod + def get(cls, metric_id, kwargs): + logger.info(f"loading metric: {metric_id}") + metric_cls = cls._registry[metric_id] + metric = metric_cls(**kwargs) + return metric + + @classmethod + def add(cls, id, metric_cls): + MetricRegistry._registry[id] = metric_cls diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py new file mode 100644 index 000000000..c7c847783 --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/reward_util.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import load_metric + + +class RougeRewardFunction(object): + def __init__(self, rouge_type, use_single_ref=True): + super().__init__() + self._metric = load_metric("rouge") + self._rouge_type = rouge_type + + self._shaping_fn = None + self._use_single_ref = use_single_ref + + def __call__( + self, + current_observation, + action, + next_observation, + done, + meta_info=None, + ): + if done: + # TBD: considers only one reference for now + if self._use_single_ref: + references = [next_observation.target_or_reference_texts[0]] + else: + references = [next_observation.target_or_reference_texts] + predicted = [next_observation.context_text] + + metric_results = self._metric.compute(predictions=predicted, references=references, use_stemmer=True) + reward = metric_results[self._rouge_type].mid.fmeasure + if self._shaping_fn is not None: + aux_score = self._shaping_fn(current_observation, action, next_observation, done, meta_info) + reward = reward + aux_score + return reward + return 0 diff --git a/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py new file mode 100644 index 000000000..0ac5ba8fe --- /dev/null +++ b/benchmark/torch/RL4LMs/rl4lms_utils/rollout_util.py @@ -0,0 +1,197 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +from .kl_controller import KLController +from parl.utils import logger +from collections import OrderedDict +from .data_wrapper import TransitionInfo + + +def get_one_token_obs(obs, idx, space): + return OrderedDict([(k, obs[k][:, idx, :]) for k in space.spaces.keys()]) + + +def unpack_observations(obs_tensor, n_instructors): + """ + Unpacks vectorized dict observations into separate dict observations + """ + unpacked_obs = [] + keys = obs_tensor.keys() + for instructor_ix in range(n_instructors): + obs_dict = {} + for key in keys: + obs_dict[key] = obs_tensor[key][instructor_ix].reshape(1, -1).cpu() + unpacked_obs.append(obs_dict) + return unpacked_obs + + +class RolloutUtil(object): + def __init__(self, kl_args): + self._kl_controller = KLController(kl_args["coeff"], kl_args["target_kl"]) + + def collect_rollouts(self, agent, instructor_group, rollout_buffer): + # get tokenizer + tokenizer = instructor_group.tokenizer + + # Switch to eval mode both training and testing + agent.eval_mode() + + # reset rollout buffer and stats + rollout_buffer.reset() + + # start the rollout process + rollout_info = { + "rollout_info/ep_rew": [], + "rollout_info/kl_div_mean": [], + "rollout_info/ep_lens": [], + "rollout_info/ep_kl_rew": [], + "rollout_info/log_prob": [], + "rollout_info/ref_log_prob": [], + "rollout_info/values": [], + } + num_timesteps = 0 + while not rollout_buffer.full: + # start parallel episodes + current_obs = instructor_group.ask() + + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here agent uses predict() rather than sample() + gen_output = agent.predict(dict_obs_tensor=current_obs, tokenizer=tokenizer) + + # get episode state, reward, dones, infos from instructors + sentence_new_obs, sentence_rewards, sentence_dones, sentence_infos = instructor_group.feedback_sentense( + gen_output=gen_output) + + # generate batch of rollouts and add to buffer + episode_wise_transitions, run_timesteps = self._generate_transition( + gen_sentence=gen_output, + init_obs=current_obs, + agent=agent, + n_instructors=instructor_group.n_instructors, + obs_space=instructor_group.observation_space, + sentence_new_obs=sentence_new_obs, + sentence_rewards=sentence_rewards, + sentence_dones=sentence_dones, + sentence_infos=sentence_infos, + ) + num_timesteps += run_timesteps + + # now we flush all episode wise info to the 1-D buffer + # log transition and add to buffer + rollout_buffer.add_transitions(episode_wise_transitions, rollout_info) + + # aggregate rollout info + aggregated_rollout_info = {} + for key, values in rollout_info.items(): + aggregated_rollout_info[key] = np.mean(values).item() + aggregated_rollout_info[f"{key}_std"] = np.std(values).item() + aggregated_rollout_info["rollout_info/kl_coeff"] = self._kl_controller.kl_coeff + + logger.info(f"Rollout Info: {aggregated_rollout_info}") + + # adapt the KL coeff + self._kl_controller.step(torch.tensor(aggregated_rollout_info["rollout_info/kl_div_mean"])) + return num_timesteps + + def 
_generate_transition(self, + gen_sentence=None, + agent=None, + n_instructors=None, + obs_space=None, + sentence_new_obs=None, + sentence_rewards=None, + sentence_dones=None, + sentence_infos=None, + init_obs=None): + current_obs = init_obs + + review_times = 0 + episode_starts = np.ones((n_instructors, ), dtype=bool) + # process them one step at a time to collect rollout info + episode_wise_transitions = [[] for _ in range(n_instructors)] + ep_terminated = np.zeros((n_instructors, ), dtype=bool) + + for idx, actions_tensor in enumerate(gen_sentence.step_wise_actions): + if np.all(ep_terminated): + break + + # evaluate actions with actions from rollout + with torch.no_grad(): + # prepare here for forward of value_model, policy_model and ref_model + obs_tensor = agent.prepare_obs_input(current_obs) + + log_probs, _, _ = agent.policy(obs=obs_tensor, actions=actions_tensor) + + # sanity check + assert torch.all(torch.isfinite(log_probs)), "Infinite values in log probs" + + # get values + values, _ = agent.value(obs_tensor) + + # get reference log probs + ref_log_probs, _, _ = agent.ref_policy(obs_tensor, actions_tensor) + + # sanity check + assert torch.all(torch.isfinite(ref_log_probs)), "Infinite values in log probs" + + # compute KL rewards + kl_div = log_probs - ref_log_probs + kl_rewards = -1 * self._kl_controller.kl_coeff * kl_div + + actions = actions_tensor.cpu().numpy() + rewards = sentence_rewards[:, idx] + dones = sentence_dones[:, idx] + new_obs = get_one_token_obs(sentence_new_obs, idx, obs_space) + infos = sentence_infos[:, idx] + + review_times += n_instructors + + # compute total rewards + total_rewards = rewards + kl_rewards.cpu().numpy() + + # unpack individual observations + unpacked_obs = unpack_observations(obs_tensor, n_instructors) + + # store episode wise transitions separately + for instructor_ix in range(n_instructors): + # only if not terminated already + if not ep_terminated[instructor_ix]: + transtion = TransitionInfo( + observation=unpacked_obs[instructor_ix], + action=actions[instructor_ix], + task_reward=rewards[instructor_ix], + total_reward=total_rewards[instructor_ix], + kl_div=kl_div.cpu().numpy()[instructor_ix], + episode_start=episode_starts[instructor_ix], + value=values[instructor_ix].cpu(), + log_prob=log_probs[instructor_ix].cpu(), + done=dones[instructor_ix], + ref_log_prob=ref_log_probs[instructor_ix].cpu(), + kl_reward=kl_rewards.cpu().numpy()[instructor_ix], + info=infos[instructor_ix], + ) + + episode_wise_transitions[instructor_ix].append(transtion) + + # mark this episode to terminated if done occurs once + if dones[instructor_ix]: + ep_terminated[instructor_ix] = True + + episode_starts = np.zeros((n_instructors, ), dtype=bool) + current_obs = new_obs + + return episode_wise_transitions, review_times diff --git a/benchmark/torch/RL4LMs/seq2seq_model.py b/benchmark/torch/RL4LMs/seq2seq_model.py new file mode 100644 index 000000000..3ef73d1d2 --- /dev/null +++ b/benchmark/torch/RL4LMs/seq2seq_model.py @@ -0,0 +1,347 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from transformers import AutoModelForSeq2SeqLM +from copy import deepcopy +from torch.distributions import Categorical + +from transformers.modeling_utils import unwrap_model + +import parl +from rl4lms_utils import ( + override_generation_routines, + GenerationInputs, + GenerationOutputs, +) + + +class Seq2SeqLMModel(parl.Model): + def __init__( + self, + observation_space, + action_space, + model_name, + weight_decay=1e-6, + apply_model_parallel=True, + optimizer_class=torch.optim.AdamW, + generation_kwargs={}, + prompt_truncation_side="left", + device=None, + ): + super(Seq2SeqLMModel, self).__init__() + + self.observation_space = observation_space + self.action_space = action_space + + self.optimizer_class = optimizer_class + self.optimizer = None + self.device = device + + self._action_space = action_space + self._apply_model_parallel = apply_model_parallel + self._build_model_heads(model_name) + self._setup_optimizer(weight_decay, optimizer_class) + self._generation_kwargs = generation_kwargs + self._prompt_truncation_side = prompt_truncation_side + + def _build_model_heads(self, model_name): + self._policy_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + self._policy_model.__class__ = override_generation_routines(type(self._policy_model)) + + self._value_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + self._ref_model = deepcopy(self._policy_model).eval() + + self._value_head = nn.Linear(self._value_model.config.hidden_size, 1, bias=False) + + # apply model parallel + if torch.cuda.is_available(): + if self._apply_model_parallel and self._policy_model.is_parallelizable: + self._policy_model.parallelize() + self._ref_model.parallelize() + self._value_model.parallelize() + self._value_head = self._value_head.to(self.device) + else: # else defaults to data parallel + self._policy_model = torch.nn.DataParallel(self._policy_model) + self._ref_model = torch.nn.DataParallel(self._ref_model) + self._value_model = torch.nn.DataParallel(self._value_model) + self._value_head = torch.nn.DataParallel(self._value_head.to(self.device)) + + # note: RL4LMs uses the same way (language model always does sample() to generate in summarization + # task) for collecting data and testing, so here policy() only needs to return info + # like log_prob and gen_kwargs without action + def policy(self, obs, actions): + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._policy_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model(self._policy_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) + + # 3. 
Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._policy_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) + + # and forward pass to get next token logits + outputs = self._policy_model(**model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # get log probs + dist = Categorical(logits=next_token_logits) + log_prob = dist.log_prob(actions) + entropy = dist.entropy() + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model(self._policy_model)._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model(self._policy_model).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + + return log_prob, entropy, past_model_kwargs + + def value(self, obs): + # 1. prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._value_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model(self._value_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) + + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._value_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) + + # and forrward pass to get hidden states + outputs = self._value_model( + **model_inputs, output_hidden_states=True, decoder_attention_mask=decoder_attn_mask, return_dict=True) + + # get decoder's last hidden state + last_tokens_hidden = outputs.decoder_hidden_states[-1][:, -1, :].to(self.device) + values = self._value_head.forward(last_tokens_hidden) + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model(self._value_model)._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model(self._value_model).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + return values, past_model_kwargs + + def evaluate_actions(self, obs, actions): + + log_prob, entropy, _ = self.policy(obs=obs, actions=actions) + values, _ = self.value(obs) + return values, log_prob, entropy + + def to(self, device): + if self._apply_model_parallel: + self._value_head = self._value_head.to(device) + return self + else: + return super().to(device) + + def ref_policy(self, obs, action): + # 1. 
prepare model inputs + past_model_kwargs = { + "attention_mask": obs["prompt_or_input_attention_mask_pt"], + } + inputs_tensor, model_input_name, past_model_kwargs = unwrap_model(self._ref_model)._prepare_model_inputs( + obs["prompt_or_input_encoded_pt"].int(), None, past_model_kwargs) + + # 2. prepare encoder outputs + past_model_kwargs = unwrap_model(self._ref_model)._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, past_model_kwargs, model_input_name) + + # 3. Prepare input_ids for auto-regressive generation + input_ids = obs["context_encoded_pt"].int() + decoder_attn_mask = obs["context_attention_mask_pt"] + + # all set to get into auto-regressive mode + # prepare all of the model inputs for the decoder + batch_size = input_ids.shape[0] + model_inputs = unwrap_model(self._ref_model).prepare_inputs_for_generation(input_ids, **past_model_kwargs) + + # and forward pass to get next token logits + outputs = self._ref_model(**model_inputs, decoder_attention_mask=decoder_attn_mask, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # get log probs + dist = Categorical(logits=next_token_logits) + log_prob = dist.log_prob(action) + entropy = dist.entropy() + + # update the model kwargs for further generation + past_model_kwargs = unwrap_model(self._ref_model)._update_model_kwargs_for_generation( + outputs, + past_model_kwargs, + is_encoder_decoder=unwrap_model(self._ref_model).config.is_encoder_decoder, + ) + past_model_kwargs["decoder_attention_mask"] = torch.cat( + (decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), + dim=-1, + ) + return log_prob, entropy, past_model_kwargs + + def get_policy_first_device(self): + return (self._policy_model.get_encoder().first_device if self._apply_model_parallel else self.device) + + def build_inputs(self, obs): + + generation_inputs = GenerationInputs(obs["prompt_or_input_encoded_pt"], + obs["prompt_or_input_attention_mask_pt"]) + return generation_inputs + + def get_language_model(self): + return unwrap_model(self._policy_model) + + def predict( + self, + tokenizer, + texts=None, + max_prompt_length=None, + input_ids=None, + attention_mask=None, + gen_kwargs=None, + ): + + # if it different from rollout gen kwargs + if gen_kwargs is None: + gen_kwargs = self._generation_kwargs + + # switch to eval + self._policy_model.eval() + + if (input_ids is None and attention_mask is None and texts is not None and max_prompt_length is not None): + # override truncation side for prompt + prev_truncation_side = tokenizer.truncation_side + tokenizer.truncation_side = self._prompt_truncation_side + encodings = tokenizer( + texts, + padding="max_length", + max_length=max_prompt_length, + return_tensors="pt", + return_attention_mask=True, + truncation=True, + ) + input_ids = encodings.input_ids + attention_mask = encodings.attention_mask + tokenizer.truncation_side = prev_truncation_side + + # if min_length argument is set and if policy is not a seq2seq LM (ie. 
causal LM) + # then it has to be adjusted to input_size + min_length + if "min_length" in gen_kwargs.keys() and not self.is_encoder_decoder(self._policy_model): + generation_kwargs_ = deepcopy(gen_kwargs) + generation_kwargs_["min_length"] = (input_ids.shape[1] + gen_kwargs["min_length"]) + else: + generation_kwargs_ = gen_kwargs + + # generate + gen_output = unwrap_model(self._policy_model).generate( + inputs=input_ids.to(self.get_policy_first_device()), + attention_mask=attention_mask.to(self.get_policy_first_device()), + return_dict_in_generate=True, + output_scores=True, + **generation_kwargs_, + ) + + # number of tokens generated + seq_length = len(gen_output["scores"]) + + # get only the generated text (excluding prompt) + gen_tokens = gen_output["sequences"][:, -seq_length:] + + # to texts + gen_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in gen_tokens.tolist()] + + # extract scores (logits) + step_wise_logprobs = [] + step_wise_actions = [] + for step, logits in enumerate(gen_output["scores"]): + raw_logits, _ = logits + actions_at_step = gen_tokens[:, step] + distribution = Categorical(logits=raw_logits) + log_probs = distribution.log_prob(actions_at_step) + step_wise_logprobs.append(log_probs) + step_wise_actions.append(actions_at_step) + + gen_output = GenerationOutputs(step_wise_logprobs, step_wise_actions, gen_tokens, gen_texts) + return gen_output + + def is_encoder_decoder(self, model): + return unwrap_model(model).config.is_encoder_decoder + + def set_training_mode(self, mode): + self.train(mode) + + def _get_constructor_parameters(self): + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + ) + + def save(self, path): + """ + Save model to a given location. + + :param path: + """ + torch.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + def _setup_optimizer( + self, + weight_decay, + optimizer_class, + ): + params = list(self.named_parameters()) + + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in params if not any(nd in n for nd in no_decay)], + "weight_decay": weight_decay, + }, + { + "params": [p for n, p in params if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + self.optimizer = optimizer_class(optimizer_grouped_parameters) diff --git a/benchmark/torch/RL4LMs/t5_ppo_config.py b/benchmark/torch/RL4LMs/t5_ppo_config.py new file mode 100644 index 000000000..6314bccbc --- /dev/null +++ b/benchmark/torch/RL4LMs/t5_ppo_config.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
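+# PPO configuration for fine-tuning t5-base on the CNN/DailyMail summarization datapool.
+# The top-level keys (tokenizer, datapool, instructor, kl_div, rollout_buffer, agent,
+# examiner, train_evaluation) match the components assembled in train.py.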
+ +config = { + 'tokenizer': { + 'model_name': 't5-base', + 'padding_side': 'left', + 'truncation_side': 'left', + 'pad_token_as_eos_token': False + }, + 'datapool': { + 'id': 'cnn_daily_mail', + 'prompt_prefix': 'Summarize: ' + }, + 'instructor': { + 'parl_master_address': 'localhost:8811', + 'n_instructors': 10, + 'reward_fn': { + 'rouge_type': 'rouge1' + }, + 'max_prompt_length': 512, + 'max_episode_length': 100, + 'terminate_on_eos': True, + 'prompt_truncation_side': 'right', + 'context_start_token': 0 + }, + 'kl_div': { + 'coeff': 0.001, + 'target_kl': 0.2 + }, + 'rollout_buffer': { + 'n_steps_per_instructor': 512 # buffer length = n_steps_per_instructor * n_instructors + }, + 'agent': { + 'batch_size': 32, + 'n_epochs': 5, + 'alg': { + 'initial_lr': 0.000002, + 'entropy_coef': 0.0 + }, + 'model': { + 'model_name': 't5-base', + 'apply_model_parallel': True, + 'prompt_truncation_side': 'right', + 'generation_kwargs': { + 'do_sample': True, + 'top_k': 50, + 'min_length': 50, + 'max_new_tokens': 100 + } + } + }, + 'examiner': { + 'max_prompt_length': + 512, + 'eval_batch_size': + 100, + 'generation_kwargs': { + 'do_sample': True, + 'top_k': 0, + 'temperature': 0.7, + 'min_length': 50, + 'max_new_tokens': 100 + }, + # metric list, each (id, args) is one metric + 'metrics': [{ + 'id': 'meteor', + 'args': {} + }, { + 'id': 'rouge' + }, { + 'id': 'bleu', + 'args': {} + }, { + 'id': 'bert_score', + 'args': { + 'language': 'en' + } + }, { + 'id': 'diversity', + 'args': {} + }] + }, + 'train_evaluation': { + 'load_model': False, + 'save_model': True, + 'n_iters': 100, + 'eval_every': 10, + 'save_every': 10, + 'checkpoint_path': "./checkpoint/checkpoint_0.pth", + 'output_dir': "./checkpoint" + } +} diff --git a/benchmark/torch/RL4LMs/train.py b/benchmark/torch/RL4LMs/train.py new file mode 100644 index 000000000..7acf7b02d --- /dev/null +++ b/benchmark/torch/RL4LMs/train.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
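+# Training entry point: builds the tokenizer, datapool and remote instructor group,
+# then alternates rollout collection (RolloutUtil.collect_rollouts) with PPO updates
+# (agent.learn), evaluating on the validation split every `eval_every` iterations and
+# saving checkpoints every `save_every` iterations.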
+ +import sys +from t5_ppo_config import config +from parl.utils import logger +import torch +import time +import os + +# instructor and reward function +from instructor import InstructorGroup + +# evaluation, metrics, tokenizer & dataset +from rl4lms_utils import build_metrics, build_tokenizer, build_datapool +from rl4lms_utils import Examiner + +# rollout +from rl4lms_utils import DictRolloutBuffer, RolloutUtil + +# agent, algorithm and model +from rl4lms_ppo import RL4LMsPPO +from rl4lms_agent import RL4LMsAgent +from seq2seq_model import Seq2SeqLMModel + + +def main(config): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + tokenizer = build_tokenizer(config["tokenizer"]) + + # datapool + samples_by_split = build_datapool(config["datapool"]) + + instructor_group = InstructorGroup( + instructor_config=config["instructor"], + tokenizer=tokenizer, + tokenizer_config=config["tokenizer"], + datapool_config=config["datapool"], + ) + + agent_config = config["agent"] + model_config = agent_config["model"] + rl4lms_model = Seq2SeqLMModel( + observation_space=instructor_group.observation_space, + action_space=instructor_group.action_space, + device=device, + model_name=model_config["model_name"], + apply_model_parallel=model_config["apply_model_parallel"], + prompt_truncation_side=model_config["prompt_truncation_side"], + generation_kwargs=model_config["generation_kwargs"]) + alg_config = agent_config["alg"] + rl4lm_alg = RL4LMsPPO( + model=rl4lms_model, initial_lr=alg_config["initial_lr"], entropy_coef=alg_config["entropy_coef"]) + agent = RL4LMsAgent( + rl4lm_alg, + n_epochs=agent_config["n_epochs"], + batch_size=agent_config["batch_size"], + ) + + buffer_config = config["rollout_buffer"] + rollout_buffer = DictRolloutBuffer( + buffer_size=buffer_config["n_steps_per_instructor"] * instructor_group.n_instructors, + observation_space=instructor_group.observation_space, + action_space=instructor_group.action_space, + device=device, + ) + rollout_util = RolloutUtil(config["kl_div"]) + + train_evaluation_config = config["train_evaluation"] + n_iters = int(train_evaluation_config["n_iters"]) + n_steps_per_iter = instructor_group.n_instructors * buffer_config["n_steps_per_instructor"] + + # gen kwargs for evaluation + examiner_config = config["examiner"] + # metrics + metrics = build_metrics(examiner_config["metrics"]) + examiner = Examiner( + tokenizer=tokenizer, + eval_batch_size=examiner_config["eval_batch_size"], + max_prompt_length=examiner_config["max_prompt_length"], + eval_gen_kwargs=examiner_config["generation_kwargs"], + metrics=metrics, + samples_by_split=samples_by_split, + ) + + if train_evaluation_config["load_model"]: + logger.info(f"loading model from {train_evaluation_config['checkpoint_path']}") + rl4lms_model.load_state_dict(torch.load(train_evaluation_config["checkpoint_path"])["state_dict"]) + iter_start = 0 + examiner.evaluate(policy=agent.alg.model, sample_name_list=["val", "test"], epoch=iter_start) + + for epoch in range(iter_start, n_iters): + print("========== BEGIN ==========") + print(f"outer epoch: {epoch} / {n_iters - 1}") + print("========== BEGIN ==========") + outer_start_time = time.time() + + num_timesteps = 0 + + while num_timesteps < n_steps_per_iter: + run_timesteps = rollout_util.collect_rollouts(agent, instructor_group, rollout_buffer) + num_timesteps += run_timesteps + agent.learn(rollout_buffer) + + outer_end_time = time.time() + print("========== END ==========") + print(f"outer epoch: {epoch} / {n_iters - 1}") + print(f"time used: {outer_end_time - outer_start_time} second(s), estimated time left:" + f" {1.0 * (outer_end_time - outer_start_time) * (n_iters - epoch - 1) / 60 / 60} hour(s)") + print("========== END ==========") + + # save model + if train_evaluation_config['save_model'] and (epoch + 1) % train_evaluation_config["save_every"] == 0: + output_dir = train_evaluation_config['output_dir'] + if not os.path.exists(output_dir): + os.mkdir(output_dir) + rl4lms_model.save(f"{output_dir}/checkpoint_{epoch}.pth") + + # evaluate on the val set at the given interval + if (epoch + 1) % train_evaluation_config["eval_every"] == 0: + examiner.evaluate(policy=agent.alg.model, sample_name_list=["val"], epoch=epoch) + + # during training, we evaluate on the VALIDATION set, and finally we evaluate on the TEST set + examiner.evaluate(policy=agent.alg.model, sample_name_list=["test"], epoch=epoch) + + +if __name__ == '__main__': + logger.auto_set_dir() + + config["logging_dir"] = logger.get_dir() + config["sys_arg"] = sys.argv + + logger.info(config) + + main(config)