diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 184d6bbf86..f2e25cee31 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -71,6 +71,27 @@ def read(self) -> List: prompt_length=len(prompt_tokens), ) exp_list.append(experience) + + elif self.prompt_type == PromptType.CHATPAIR: + for prompt_messages, response_messages in zip( + batch_data[self.prompt_key], batch_data[self.response_key] + ): + full_messages = prompt_messages + response_messages + + tokens = self.tokenizer.apply_chat_template( + full_messages, add_generation_prompt=False, return_tensors="pt" + )[0] + + prompt_tokens = self.tokenizer.apply_chat_template( + prompt_messages, add_generation_prompt=True, return_tensors="pt" + )[0] + + experience = Experience( + tokens=tokens, + prompt_length=len(prompt_tokens), + ) + exp_list.append(experience) + elif self.prompt_type == PromptType.PLAINTEXT: # TODO: support HF format without chat template for prompt, response in zip(batch_data[self.prompt_key], batch_data[self.response_key]): diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 7ff742c94a..2e980632e5 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -29,10 +29,16 @@ def explore(config: Config) -> None: def train(config: Config) -> None: """Run trainer.""" - algo_type = config.trainer.algorithm_type trainer = Trainer.remote(config) + ray.get(trainer.prepare.remote()) + + if config.trainer.sft_warmup_iteration > 0: + for step in range(config.trainer.sft_warmup_iteration): + ray.get(trainer.train_step.remote(AlgorithmType.SFT)) + logger.info(f"SFT warmup step {step} finished.") + + algo_type = config.trainer.algorithm_type try: - ray.get(trainer.prepare.remote()) ray.get(trainer.train.remote(algo_type)) logger.info("Train finished.") except Exception as e: diff --git a/trinity/common/constants.py b/trinity/common/constants.py index b72e945f63..154eb360d2 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -27,8 +27,9 @@ class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta): class PromptType(CaseInsensitiveEnum): """Prompt Type.""" - MESSAGES = "messages" - PLAINTEXT = "plaintext" + MESSAGES = "messages" # prompt+response: message list + CHATPAIR = "chatpair" # prompt: message list, response: message list + PLAINTEXT = "plaintext" # prompt: plaintext, response: plaintext class TaskType(Enum): diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 8d64f060c9..03c60cf2bb 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -511,7 +511,7 @@ def sync_weight(self) -> None: def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None: self.actor_rollout_wg.set_mode(algorithm_type) - if algorithm_type.is_rft() and self.algorithm_type.is_sft(): + if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()): self.sft_to_rft() self.algorithm_type = algorithm_type