Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ def read(self) -> List:
prompt_length=len(prompt_tokens),
)
exp_list.append(experience)

elif self.prompt_type == PromptType.CHATPAIR:
for prompt_messages, response_messages in zip(
batch_data[self.prompt_key], batch_data[self.response_key]
):
full_messages = prompt_messages + response_messages

tokens = self.tokenizer.apply_chat_template(
full_messages, add_generation_prompt=False, return_tensors="pt"
)[0]

prompt_tokens = self.tokenizer.apply_chat_template(
prompt_messages, add_generation_prompt=True, return_tensors="pt"
)[0]

experience = Experience(
tokens=tokens,
prompt_length=len(prompt_tokens),
)
exp_list.append(experience)

elif self.prompt_type == PromptType.PLAINTEXT:
# TODO: support HF format without chat template
for prompt, response in zip(batch_data[self.prompt_key], batch_data[self.response_key]):
Expand Down
10 changes: 8 additions & 2 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ def explore(config: Config) -> None:
def train(config: Config) -> None:
"""Run trainer."""

algo_type = config.trainer.algorithm_type
trainer = Trainer.remote(config)
ray.get(trainer.prepare.remote())

if config.trainer.sft_warmup_iteration > 0:
for step in range(config.trainer.sft_warmup_iteration):
ray.get(trainer.train_step.remote(AlgorithmType.SFT))
logger.info(f"SFT warmup step {step} finished.")

algo_type = config.trainer.algorithm_type
try:
ray.get(trainer.prepare.remote())
ray.get(trainer.train.remote(algo_type))
logger.info("Train finished.")
except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions trinity/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
class PromptType(CaseInsensitiveEnum):
"""Prompt Type."""

MESSAGES = "messages"
PLAINTEXT = "plaintext"
MESSAGES = "messages" # prompt+response: message list
CHATPAIR = "chatpair" # prompt: message list, response: message list
PLAINTEXT = "plaintext" # prompt: plaintext, response: plaintext


class TaskType(Enum):
Expand Down
2 changes: 1 addition & 1 deletion trinity/trainer/verl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def sync_weight(self) -> None:

def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None:
self.actor_rollout_wg.set_mode(algorithm_type)
if algorithm_type.is_rft() and self.algorithm_type.is_sft():
if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()):
self.sft_to_rft()
self.algorithm_type = algorithm_type

Expand Down