Commit d20c8ff

Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr
* tested deepseek r1 pipeline
* add ci
* verify grpo r1
* update readme, remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove path
* clean code
* fix circular import
* fix ci OOM
* skip kto tp, fix qwen generation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ce0ec40 commit d20c8ff

39 files changed (+1993, -275 lines)

.github/workflows/run_chatgpt_examples.yml

Lines changed: 1 addition & 0 deletions

@@ -61,5 +61,6 @@ jobs:
       PRETRAINED_MODEL_PATH: ./models
       SFT_DATASET: ./sft_data
       PROMPT_DATASET: ./prompt_data
+      PROMPT_RLVR_DATASET: ./prompt_data
       PREFERENCE_DATASET: ./preference_data
       KTO_DATASET: ./kto_data

applications/ColossalChat/.gitignore

Lines changed: 1 addition & 0 deletions

@@ -158,6 +158,7 @@ temp/
 applications/ColossalChat/logs
 applications/ColossalChat/models
 applications/ColossalChat/sft_data
+applications/ColossalChat/kto_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp

applications/ColossalChat/coati/dataset/conversation.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def setup_conversation_template(
             pass
         except ValueError as e:
             raise ValueError(e)
-    if not dist.is_initialized() or dist.get_rank() == 0:
+    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
         with open(save_path, "w", encoding="utf8") as f:
             logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")
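The new condition makes persisting the template config optional and keeps the file write on a single process. A minimal standalone sketch of the same guard pattern; the save_config_on_rank0 helper is illustrative, not code from this commit:

import os
from typing import Optional

import torch.distributed as dist


def save_config_on_rank0(config_str: str, save_path: Optional[str] = None) -> None:
    # Mirrors the guard above: write only when a path is given, and only on
    # rank 0 (or when torch.distributed has not been initialized at all).
    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
        # Assumes save_path includes a directory component, as in the diff.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w", encoding="utf8") as f:
            f.write(config_str)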

applications/ColossalChat/coati/dataset/loader.py

Lines changed: 2 additions & 1 deletion

@@ -155,13 +155,14 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch
             `input_ids`: `torch.Tensor` of shape (bsz, max_len);
             `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
         """
+        gt_answer = [ins.get("gt_answer", None) for ins in instances]
         instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
         ret = super().__call__(instances=instances)
         input_ids = F.pad(
             ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
         )
         attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}


 @dataclass
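The collator now threads each sample's ground-truth answer through the batch so a verifiable reward can later be scored against the model's generations. A minimal sketch of how such a batch might be consumed downstream; the extract_final_answer helper and the exact-match scoring are illustrative assumptions, not code from this commit:

import re
from typing import List, Optional

import torch


def extract_final_answer(completion: str) -> Optional[str]:
    # Hypothetical parser: pull the final \boxed{...} answer out of a completion.
    match = re.search(r"\\boxed\{(.+?)\}", completion)
    return match.group(1).strip() if match else None


def rlvr_reward(completions: List[str], gt_answer: List[Optional[str]]) -> torch.Tensor:
    # Exact-match verifiable reward: 1.0 when the parsed answer equals the
    # ground truth carried through the collator, else 0.0. Samples whose
    # gt_answer is None (plain PPO prompts) simply score 0.0 here.
    rewards = []
    for completion, answer in zip(completions, gt_answer):
        predicted = extract_final_answer(completion)
        rewards.append(1.0 if answer is not None and predicted == answer else 0.0)
    return torch.tensor(rewards)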

applications/ColossalChat/coati/dataset/tokenization_utils.py

Lines changed: 15 additions & 8 deletions

@@ -147,7 +147,6 @@ def tokenize_prompt(
         ignore_index: the ignore index when calculate loss during training
         max_length: the maximum context length
     """
-
     messages = data_point["messages"]
     template = deepcopy(conversation_template)
     template.messages = []

@@ -167,7 +166,6 @@ def tokenize_prompt(
     if len(template.messages) % 2 != 1:
         # exclude the answer if provided. keep only the prompt
         template.messages = template.messages[:-1]
-
     # Prepare data
     prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]

@@ -185,12 +183,21 @@ def tokenize_prompt(
         )

     # `inputs_decode` can be used to check whether the tokenization method is true.
-    return dict(
-        input_ids=tokenized,
-        inputs_decode=prompt,
-        seq_length=len(tokenized),
-        seq_category=data_point["category"] if "category" in data_point else "None",
-    )
+    if "gt_answer" in data_point:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+            gt_answer=data_point["gt_answer"],
+        )
+    else:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+        )


 def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
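With this change a prompt record can optionally carry its verifiable answer end to end. A hedged sketch of what such a record might look like; the message field names and the concrete values are assumptions for illustration, and only the messages and gt_answer keys come from the diff above:

# Illustrative RLVR prompt record. A record without "gt_answer" still yields
# the original PPO-style dict from tokenize_prompt; with it, the returned
# dict additionally carries gt_answer for reward verification.
data_point = {
    "messages": [
        {"from": "user", "content": "What is 17 * 23? Put your final answer in \\boxed{}."},
    ],
    "gt_answer": "391",
}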

applications/ColossalChat/coati/experience_buffer/naive.py

Lines changed: 9 additions & 1 deletion

@@ -27,6 +27,8 @@ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = T
         self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
         # TODO(ver217): add prefetch
         self.items: List[BufferItem] = []
+        self.rng_sequence = []
+        self.ptr = 0

     @torch.no_grad()
     def append(self, experience: Experience) -> None:

@@ -40,6 +42,9 @@ def append(self, experience: Experience) -> None:
         if samples_to_remove > 0:
             logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
             self.items = self.items[samples_to_remove:]
+        self.rng_sequence = [i for i in range(len(self.items))]
+        random.shuffle(self.rng_sequence)
+        self.ptr = 0

     def clear(self) -> None:
         self.items.clear()

@@ -52,7 +57,10 @@ def sample(self) -> Experience:
         Returns:
             A batch of sampled experiences.
         """
-        items = random.sample(self.items, self.sample_batch_size)
+        items = []
+        for _ in range(self.sample_batch_size):
+            self.ptr = (self.ptr + 1) % len(self.items)
+            items.append(self.items[self.rng_sequence[self.ptr]])
         experience = make_experience_batch(items)
         if self.cpu_offload:
             experience.to_device(self.target_device)
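random.sample draws independently on every call, so some experiences may never be replayed while others repeat; the new scheme walks a pre-shuffled index sequence with a wrap-around pointer, visiting every stored item once per pass. A standalone sketch of that sampling discipline, assuming it mirrors the diff above:

import random


class ShuffledRoundRobin:
    # Minimal model of the sampling scheme in the experience buffer above.
    def __init__(self, items, sample_batch_size):
        self.items = list(items)
        self.sample_batch_size = sample_batch_size
        # One shuffled pass over all indices; the real buffer rebuilds this
        # sequence (and resets ptr) whenever new experiences are appended.
        self.rng_sequence = list(range(len(self.items)))
        random.shuffle(self.rng_sequence)
        self.ptr = 0

    def sample(self):
        batch = []
        for _ in range(self.sample_batch_size):
            # Advance with wrap-around, exactly as in the diff above.
            self.ptr = (self.ptr + 1) % len(self.items)
            batch.append(self.items[self.rng_sequence[self.ptr]])
        return batch


buffer = ShuffledRoundRobin(list("abcdefgh"), sample_batch_size=4)
print(buffer.sample())  # 4 distinct items in shuffled order
print(buffer.sample())  # the remaining 4 items before the cycle repeats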
