Skip to content

Commit f60adff

Browse files
authored
Add sft warmup before dpo (#12)
1 parent 65f23ab commit f60adff

File tree

4 files changed

+33
-5
lines changed

4 files changed

+33
-5
lines changed

trinity/buffer/reader/file_reader.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ def read(self) -> List:
7171
prompt_length=len(prompt_tokens),
7272
)
7373
exp_list.append(experience)
74+
75+
elif self.prompt_type == PromptType.CHATPAIR:
76+
for prompt_messages, response_messages in zip(
77+
batch_data[self.prompt_key], batch_data[self.response_key]
78+
):
79+
full_messages = prompt_messages + response_messages
80+
81+
tokens = self.tokenizer.apply_chat_template(
82+
full_messages, add_generation_prompt=False, return_tensors="pt"
83+
)[0]
84+
85+
prompt_tokens = self.tokenizer.apply_chat_template(
86+
prompt_messages, add_generation_prompt=True, return_tensors="pt"
87+
)[0]
88+
89+
experience = Experience(
90+
tokens=tokens,
91+
prompt_length=len(prompt_tokens),
92+
)
93+
exp_list.append(experience)
94+
7495
elif self.prompt_type == PromptType.PLAINTEXT:
7596
# TODO: support HF format without chat template
7697
for prompt, response in zip(batch_data[self.prompt_key], batch_data[self.response_key]):

trinity/cli/launcher.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@ def explore(config: Config) -> None:
2929
def train(config: Config) -> None:
3030
"""Run trainer."""
3131

32-
algo_type = config.trainer.algorithm_type
3332
trainer = Trainer.remote(config)
33+
ray.get(trainer.prepare.remote())
34+
35+
if config.trainer.sft_warmup_iteration > 0:
36+
for step in range(config.trainer.sft_warmup_iteration):
37+
ray.get(trainer.train_step.remote(AlgorithmType.SFT))
38+
logger.info(f"SFT warmup step {step} finished.")
39+
40+
algo_type = config.trainer.algorithm_type
3441
try:
35-
ray.get(trainer.prepare.remote())
3642
ray.get(trainer.train.remote(algo_type))
3743
logger.info("Train finished.")
3844
except Exception as e:

trinity/common/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
2727
class PromptType(CaseInsensitiveEnum):
2828
"""Prompt Type."""
2929

30-
MESSAGES = "messages"
31-
PLAINTEXT = "plaintext"
30+
MESSAGES = "messages" # prompt+response: message list
31+
CHATPAIR = "chatpair" # prompt: message list, response: message list
32+
PLAINTEXT = "plaintext" # prompt: plaintext, response: plaintext
3233

3334

3435
class TaskType(Enum):

trinity/trainer/verl_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def sync_weight(self) -> None:
511511

512512
def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None:
513513
self.actor_rollout_wg.set_mode(algorithm_type)
514-
if algorithm_type.is_rft() and self.algorithm_type.is_sft():
514+
if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()):
515515
self.sft_to_rft()
516516
self.algorithm_type = algorithm_type
517517

0 commit comments

Comments
 (0)