
Commit 5015300

[feat] add microbatch forwarding (#6251)
* add microbatch forwarding
* fix forward microbatch
* fix producer OOM
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* change project name
* fix temperature annealing
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 489f215 commit 5015300

5 files changed, +113 -73 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 6 deletions
@@ -66,12 +66,7 @@ def setup(self) -> None:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

-        plugin_config = dict(
-            tp_size=1,
-            pp_size=1,
-            precision="bf16",
-            zero_stage=1,
-        )
+        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
         plugin_config.update(self.plugin_config)
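
Note on the consumer.py change: the default plugin config is collapsed to a single dict literal and the ZeRO stage moves from 1 to 2, while user-supplied entries in self.plugin_config still win because they are merged last via plugin_config.update(self.plugin_config). A minimal sketch of that merge order, using plain dicts in place of the real plugin arguments (the user_plugin_config values below are made up for illustration):

    plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)  # new defaults

    user_plugin_config = {"pp_size": 2, "num_microbatches": 4}  # hypothetical user overrides
    if user_plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in user_plugin_config:
        plugin_config["microbatch_size"] = 4  # stands in for self.microbatch_size
    plugin_config.update(user_plugin_config)  # user values override the defaults last

    print(plugin_config)
    # {'tp_size': 1, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 2, 'num_microbatches': 4}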

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 81 additions & 52 deletions
@@ -68,6 +68,7 @@ def __init__(
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -131,40 +132,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))

         need_update = (step_idx + 1) % self.num_microbatches == 0
-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
-            action_log_probs = calc_action_log_probs(
-                policy_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
-            reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            per_token_kl = (
-                torch.exp(reference_action_log_probs - action_log_probs)
-                - (reference_action_log_probs - action_log_probs)
-                - 1
-            )
-            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
-
             reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
@@ -177,6 +154,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:

             group_reward = reward.view(-1, self.num_generations)
             reward_mean = group_reward.mean(dim=1)
+            # [batch_size x num_generations]
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
@@ -185,35 +167,82 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
                 ).repeat_interleave(self.num_generations, dim=0)
             )
+            mean_kl, mean_loss = [], []
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids_forward_micro_batch,
+                    attention_mask=attention_mask_forward_micro_batch,
+                ).logits
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )

-            # [batch_size x num_generations]
-            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
-            advantages = (reward - reward_mean) / (reward_std + 1e-4)
-
-            loss, skip_update, _ = self.policy_loss_fn(
-                action_log_probs,
-                old_action_log_probs,
-                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
-                per_token_kl,
-                action_mask,
-                loss_mask=loss_mask,
-            )
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                reference_action_log_probs = calc_action_log_probs(
+                    reference_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
+
+                per_token_kl = (
+                    torch.exp(reference_action_log_probs - action_log_probs)
+                    - (reference_action_log_probs - action_log_probs)
+                    - 1
+                )
+                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                    action_mask_forward_micro_batch, dim=-1
+                )
+
+                loss, skip_update, _ = self.policy_loss_fn(
+                    action_log_probs,
+                    old_action_log_probs,
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    per_token_kl,
+                    action_mask_forward_micro_batch,
+                    loss_mask=loss_mask_forward_micro_batch,
+                )
+
+                if not skip_update:
+                    self.booster.backward(loss, self.optimizer)
+                loss = all_reduce_mean(loss, self.plugin)
+                kl = all_reduce_mean(kl.mean(), self.plugin)
+                # Calculate accumulate value.
+                mean_kl.append(kl.data)
+                mean_loss.append(loss.data)

-            if not skip_update:
-                self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
-            kl = all_reduce_mean(kl.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            # Calculate accumulate value.
-            self.accum_loss.add_(loss.data)
+            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
             self.accum_reward.add_(reward.data)
-            self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
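
The core of this commit is the loop above: instead of a single forward pass over the whole rollout batch, data["input_ids"] is sliced along dim 0 into chunks of forward_batch_size, each chunk runs its own policy and reference forwards, KL, loss, and backward (so gradients accumulate across chunks), and the per-chunk losses and KLs are averaged into accum_loss and accum_kl afterwards. A minimal self-contained sketch of that slicing-and-accumulation pattern, with a toy model and loss standing in for the real policy model and GRPO loss:

    import torch

    batch_size, seq_len, forward_batch_size = 8, 16, 2
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    advantages = torch.randn(batch_size)
    # Toy stand-ins for the policy model and optimizer.
    model = torch.nn.Sequential(torch.nn.Embedding(100, 32), torch.nn.Linear(32, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    mean_loss = []
    for start in range(0, batch_size, forward_batch_size):
        ids = input_ids[start : start + forward_batch_size]
        adv = advantages[start : start + forward_batch_size]
        scores = model(ids).mean(dim=(1, 2))    # [forward_batch_size], toy "log-prob" per sample
        loss = -(scores * adv).mean()           # toy policy-gradient-style loss
        loss.backward()                         # gradients accumulate across chunks
        mean_loss.append(loss.detach())

    optimizer.step()                            # one optimizer step for the whole batch
    optimizer.zero_grad()
    print(sum(mean_loss) / len(mean_loss))      # averaged, as accum_loss is updated above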

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 7 additions & 2 deletions
@@ -34,6 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_microbatch_size: int,
+    train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
@@ -99,9 +100,13 @@ def launch_distributed(
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_microbatch_size,
+            microbatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={
+                "filter_range": [0.05, 9.0],
+                "lr": 1e-6,
+                "train_microbatch_size": train_microbatch_size,
+            },
             num_generations=num_generations,
         )
         procs.append(consumer)
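
launch.py now passes train_minibatch_size as the consumer's microbatch_size argument, while train_microbatch_size is carried inside training_config and, per the grpo_consumer.py change above, only controls how each training step's forward pass is chunked. A tiny sketch of the fallback the consumer applies when that key is absent, using a plain dict in place of the real config (full_batch is a made-up stand-in for data["input_ids"].size(0)):

    training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6, "train_microbatch_size": 2}
    full_batch = 8  # stand-in for data["input_ids"].size(0)

    forward_batch_size = training_config.get("train_microbatch_size", full_batch)
    assert forward_batch_size == 2  # falls back to the full batch if the key is missing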

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 7 additions & 3 deletions
@@ -100,6 +100,7 @@ def loop(self) -> None:
                 if i >= num_valid_microbatches:
                     break
                 outputs = self.rollout(**batch)
+
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
@@ -116,16 +117,19 @@ def loop(self) -> None:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
+
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    del state_dict
+                    torch.cuda.empty_cache()
                     # linear annealing for 1 episode, temperature from initial to 0.7
                     if episode <= 0:
                         ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                        self.model.generate_config.temperature = (
-                            ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
-                        )
+                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.7


 @ray.remote
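
The producer change does two things: del state_dict plus torch.cuda.empty_cache() frees the broadcast weights right after they are loaded (the "fix producer OOM" item in the commit message), and the annealing formula is flipped so that with ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) growing from 0 toward 1 over the first episode, the temperature now starts at the configured value and moves linearly toward 0.7 rather than the reverse. A small worked sketch of the corrected schedule, assuming an initial temperature of 0.9 and a 4-batch dataloader (both values chosen only for illustration):

    initial_temperature, num_batches = 0.9, 4  # assumed values for illustration

    for i in range(num_batches):
        ratio = 1 - (num_batches - i) / num_batches                     # 0.00, 0.25, 0.50, 0.75
        temperature = (1 - ratio) * initial_temperature + ratio * 0.7   # corrected interpolation
        print(f"step {i}: temperature = {temperature:.3f}")             # 0.900, 0.850, 0.800, 0.750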

applications/ColossalChat/rl_example.py

Lines changed: 17 additions & 10 deletions
@@ -10,18 +10,30 @@
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
 parser.add_argument("-t", "--num-trainers", type=int, default=2)
 parser.add_argument("-i", "--num-inferencer", type=int, default=2)
+parser.add_argument("-g", "--num-generations", type=int, default=8)
 parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
 parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
 parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
+parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
+parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
 parser.add_argument("-b", "--backend", type=str, default="transformers")
 parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
 args = parser.parse_args()

+assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
+assert (
+    args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
+    and args.train_microbatch_size > 0
+), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
+
 ray.init(address="local", namespace="ray-example")

 inference_model_config = dict(path=args.model)
-train_model_config = dict(path=args.model)
+train_model_config = dict(
+    path=args.model,
+    # use_flash_attention_2=True,
+    # use_cache=False
+)
 generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

 if args.backend == "transformers":
@@ -31,13 +43,6 @@
             torch_dtype=torch.bfloat16,
         )
     )
-    train_model_config.update(
-        dict(
-            use_flash_attention_2=True,
-            torch_dtype=torch.bfloat16,
-            use_cache=False,
-        )
-    )
     generate_config.update(
         dict(
             max_length=1024 + 512,
@@ -78,15 +83,17 @@
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
+        train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
         dataset_config={"path": args.dataset, "max_length": 300},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
+        num_generations=args.num_generations,
         train_model_config=train_model_config,
         plugin_config={},
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29503,
+        master_port=29505,
         core_algo=args.algo,
     )
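
rl_example.py now distinguishes -tMbs/--train-minibatch-size from -tmbs/--train-microbatch-size and validates their relationship before launching. A quick check of the new constraints with the script defaults (-g 8, -tMbs 1, -tmbs 2); argument parsing itself is omitted:

    num_generations, train_minibatch_size, train_microbatch_size = 8, 1, 2  # script defaults

    assert train_minibatch_size > 0
    assert (
        train_minibatch_size * num_generations >= train_microbatch_size
        and train_microbatch_size > 0
    )
    # 1 * 8 = 8 >= 2 > 0, so the defaults pass both checks.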
