Skip to content

Commit 0472f44

Browse files
committed
fix logprob, add filtering, temperature annealing, lr descent
1 parent 7ee4452 commit 0472f44

File tree

7 files changed

+74
-27
lines changed

7 files changed

+74
-27
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
5858

5959
self.device = get_current_device()
60+
self.lr_scheduler = None
6061

6162
def setup(self) -> None:
6263
for i in range(self.num_producers):
@@ -121,6 +122,8 @@ def loop(self) -> None:
121122
pbar.set_postfix({"loss": loss})
122123
i += 1
123124
assert len(self.buffer) == 0
125+
if self.lr_scheduler is not None:
126+
self.lr_scheduler.step()
124127
if (step + 1) % self.save_interval == 0:
125128
if self.rank == 0:
126129
print(f"Start saving policy model at step {step + 1}.")

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from coati.trainer.utils import all_reduce_mean
1616
from transformers import AutoModelForCausalLM, AutoTokenizer
1717

18+
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
1819
from colossalai.nn.optimizer import HybridAdam
1920

2021

@@ -34,10 +35,10 @@ def __init__(
3435
model_config,
3536
plugin_config,
3637
microbatch_size=1,
37-
num_generations=4,
38+
num_generations=8,
3839
use_wandb=True,
39-
generator_config=None,
40-
filter_range=None,
40+
generate_config=None,
41+
training_config={},
4142
):
4243
super().__init__(
4344
num_producers,
@@ -57,7 +58,7 @@ def __init__(
5758
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
5859
self.policy_model.train()
5960
self.policy_model.gradient_checkpointing_enable()
60-
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
61+
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
6162
self.accum_loss = torch.zeros(1, device=self.device)
6263
self.accum_reward = torch.zeros(1, device=self.device)
6364
self.accum_kl = torch.zeros(1, device=self.device)
@@ -66,6 +67,7 @@ def __init__(
6667
self.accum_advantages = torch.zeros(1, device=self.device)
6768
self.accum_response_length = torch.zeros(1, device=self.device)
6869
self.accum_count = 0
70+
self.generate_config = generate_config
6971

7072
# Reference model is initialized from policy model.
7173
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -74,7 +76,7 @@ def __init__(
7476
self.tokenizer = AutoTokenizer.from_pretrained(path)
7577
self.pad_token_id = self.tokenizer.pad_token_id
7678
self.num_generations = num_generations
77-
self.filter_range = filter_range
79+
self.filter_range = training_config.get("filter_range", None)
7880
if self.filter_range is not None:
7981
assert len(self.filter_range) == 2, "Filter range should have 2 values."
8082

@@ -92,15 +94,21 @@ def __init__(
9294
self.policy_loss_fn = PolicyLoss()
9395
self.global_step = 0
9496
if use_wandb and self.rank == 0:
95-
if "repetition_penalty" in generator_config:
96-
name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
97-
else:
98-
name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
97+
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
9998
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
10099

100+
self.lr_scheduler = CosineAnnealingWarmupLR(
101+
optimizer=self.optimizer,
102+
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
103+
warmup_steps=0,
104+
eta_min=0.1 * training_config.get("lr", 1e-6),
105+
)
106+
101107
def setup(self):
102108
super().setup()
103-
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
109+
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
110+
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
111+
)
104112
self.reference_model, *_ = self.booster.boost(self.reference_model)
105113

106114
def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -133,7 +141,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
133141
attention_mask=data["attention_mask"],
134142
)["logits"]
135143
action_log_probs = calc_action_log_probs(
136-
policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
144+
policy_model_logits / self.generate_config["temperature"],
145+
data["input_ids"],
146+
num_action,
147+
self.plugin.shard_config,
137148
)
138149

139150
with torch.no_grad():
@@ -142,7 +153,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
142153
attention_mask=data["attention_mask"],
143154
)["logits"]
144155
reference_action_log_probs = calc_action_log_probs(
145-
reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
156+
reference_model_logits / self.generate_config["temperature"],
157+
data["input_ids"],
158+
num_action,
159+
self.plugin.shard_config,
146160
)
147161

148162
per_token_kl = (
@@ -161,22 +175,24 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
161175
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
162176

163177
# [batch_size, num_generations]
178+
179+
group_reward = reward.view(-1, self.num_generations)
180+
reward_mean = group_reward.mean(dim=1)
164181
# filter out the reward that is too high (all samples get full score) or too low (all samples get 0 score).
165182
loss_mask = (
166183
None
167184
if self.filter_range is None
168-
else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
185+
else torch.logical_and(
186+
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
187+
).repeat_interleave(self.num_generations, dim=0)
169188
)
170-
group_reward = reward.view(-1, self.num_generations)
171-
reward_mean = group_reward.mean(dim=1)
172189

173190
# [batch_size x num_generations]
174-
reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
191+
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
175192
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
176193
# [batch_size x num_generations]
177194
advantages = (reward - reward_mean) / (reward_std + 1e-4)
178195

179-
# Calculate Loss
180196
loss, skip_update, _ = self.policy_loss_fn(
181197
action_log_probs,
182198
old_action_log_probs,

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,21 @@ class TransformersInferenceBackend(BaseInferenceBackend):
5353
)
5454
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
5555

56-
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
56+
def __init__(
57+
self,
58+
model_config: Dict[str, Any],
59+
generate_config: Dict[str, Any],
60+
tokenizer: PreTrainedTokenizer,
61+
num_generations: int = 8,
62+
):
5763
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
5864
model_config.update(self.FORCE_MODEL_CONFIG)
5965
path = model_config.pop("path")
6066
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
6167
self.generate_config = generate_config.copy()
6268
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
6369
self.tokenizer = tokenizer
64-
self.num_generations = 8
70+
self.num_generations = num_generations
6571

6672
@torch.no_grad()
6773
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
120126

121127

122128
class SGLangInferenceBackend(BaseInferenceBackend):
123-
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
129+
def __init__(
130+
self,
131+
model_config: Dict[str, Any],
132+
generate_config: Dict[str, Any],
133+
tokenizer: PreTrainedTokenizer,
134+
num_generations: int = 8,
135+
):
124136
if sgl is None:
125137
raise ImportError("sglang is not installed")
126138
path = model_config.pop("path")
@@ -175,20 +187,26 @@ class VLLMInferenceBackend(BaseInferenceBackend):
175187
)
176188
FORCE_GENERATE_CONFIG = dict(
177189
logprobs=0,
178-
n=8,
179190
)
180191

181-
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
192+
def __init__(
193+
self,
194+
model_config: Dict[str, Any],
195+
generate_config: Dict[str, Any],
196+
tokenizer: PreTrainedTokenizer,
197+
num_generations: int = 8,
198+
):
182199
if LLM is None:
183200
raise ImportError("vllm is not installed")
184201
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
185202
path = model_config.pop("path")
186203
self.llm = LLM(model=path, **model_config)
187204
generate_config = generate_config.copy()
188205
generate_config.update(self.FORCE_GENERATE_CONFIG)
206+
generate_config.update({"n": num_generations})
189207
self.generate_config = SamplingParams(**generate_config)
190208
self.tokenizer = tokenizer
191-
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
209+
self.num_generations = num_generations
192210

193211
@torch.no_grad()
194212
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def launch_distributed(
4242
plugin_config: Dict[str, Any],
4343
tokenizer_config: Optional[Dict[str, Any]] = None,
4444
inference_backend: str = "transformers",
45+
num_generations: int = 8,
4546
master_addr: str = "localhost",
4647
master_port: int = 29500,
4748
core_algo: str = "GRPO",
@@ -76,6 +77,7 @@ def launch_distributed(
7677
tokenizer_config=tokenizer_config,
7778
microbatch_size=inference_microbatch_size,
7879
backend=inference_backend,
80+
num_generations=num_generations,
7981
)
8082
procs.append(producer)
8183
generate_config_consumer = copy.deepcopy(generate_config)
@@ -99,7 +101,8 @@ def launch_distributed(
99101
plugin_config=plugin_config,
100102
microbatch_size=train_microbatch_size,
101103
generate_config=generate_config_consumer,
102-
filter_range=[0.05, 9.0],
104+
training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
105+
num_generations=num_generations,
103106
)
104107
procs.append(consumer)
105108
ray.get([p.setup.remote() for p in procs])

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def loop(self) -> None:
117117
None, self.num_producers, device=self.device, group_name="sync_model"
118118
)
119119
self.load_state_dict(state_dict)
120+
# linear annealing for 1 episode, temperature from initial to 0.7
121+
if episode <= 0:
122+
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
123+
self.model.generate_config.temperature = (
124+
ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
125+
)
120126

121127

122128
@ray.remote
@@ -135,6 +141,7 @@ def __init__(
135141
tokenizer_config=None,
136142
microbatch_size=1,
137143
backend="transformers",
144+
num_generations: int = 8,
138145
):
139146
super().__init__(
140147
producer_idx,
@@ -150,7 +157,7 @@ def __init__(
150157
microbatch_size,
151158
backend,
152159
)
153-
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
160+
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
154161

155162
@torch.no_grad()
156163
def rollout(self, input_ids, attention_mask, **kwargs):

applications/ColossalChat/rl_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
inference_model_config = dict(path=args.model)
2424
train_model_config = dict(path=args.model)
25-
generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
25+
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
2626

2727
if args.backend == "transformers":
2828
inference_model_config.update(

colossalai/shardformer/layer/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def dist_log_prob(
387387
dtype=dtype,
388388
)
389389
else:
390-
log_prob = log_softmax(logits)
390+
log_prob = log_softmax(logits, dim=-1)
391391
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
392392

393393
return log_prob

0 commit comments

Comments
 (0)