
Commit a960990

optimize pp log_softmax OOM

1 parent 0f71c79

File tree

2 files changed: +28 -11 lines


.gitignore

Lines changed: 2 additions & 0 deletions

@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 26 additions & 11 deletions

@@ -293,13 +293,21 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     )

                     if self.booster.plugin.stage_manager.is_last_stage():
-                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                        reference_action_log_probs = calc_action_log_probs(
-                            reference_model_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
+                        reference_action_log_probs = torch.zeros(
+                            (input_ids_forward_micro_batch.size(0), num_action),
+                            device=input_ids_forward_micro_batch.device,
                         )
+                        for i in range(reference_action_log_probs.size(0)):
+                            # The log_softmax activation is too large when the vocab size and sequence length are large.
+                            # E.g., with a 152064 vocab size, a 32K (32000) sequence length, and a micro batch size of 4
+                            # (e.g. for pp=4), this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5 GB.
+                            reference_action_log_probs[i, :] += calc_action_log_probs(
+                                reference_model_outputs["outputs"]["logits"][i : i + 1]
+                                / self.generate_config["temperature"],
+                                input_ids_forward_micro_batch[i : i + 1],
+                                num_action,
+                                self.plugin.shard_config,
+                            )[0]
                     else:
                         # Dummy reference logprobs for data iterator.
                         reference_action_log_probs = None

@@ -321,12 +329,19 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        inputs["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
+                    action_log_probs = torch.zeros(
+                        (inputs["input_ids"].size(0), num_action), device=action_logits.device
                     )
+                    for i in range(action_log_probs.size(0)):
+                        # The log_softmax activation is too large when the vocab size and sequence length are large.
+                        # E.g., with a 152064 vocab size, a 32K (32000) sequence length, and a micro batch size of 4
+                        # (e.g. for pp=4), this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5 GB.
+                        action_log_probs[i, :] += calc_action_log_probs(
+                            action_logits[i : i + 1] / self.generate_config["temperature"],
+                            inputs["input_ids"][i : i + 1],
+                            num_action,
+                            self.plugin.shard_config,
+                        )[0]
                     if "reference_action_log_probs" in inputs:
                         per_token_kl = (
                             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
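
The change trades one full-batch log_softmax for a per-sample loop: the [batch, seq_len, vocab] float32 activation is never materialized, only [1, seq_len, vocab] slices, so peak memory drops by roughly a factor of the micro batch size (72.5 GB to about 18 GB in the commit's example) while the result stays numerically identical. Below is a minimal standalone sketch of the same pattern in plain PyTorch. action_log_probs_per_sample is a hypothetical stand-in for ColossalAI's calc_action_log_probs, whose sharded internals aren't shown in this diff; it assumes the helper gathers next-token log-probs and keeps the last num_action positions.

    import torch
    import torch.nn.functional as F

    def action_log_probs_per_sample(logits, input_ids, num_action):
        # Hypothetical stand-in for calc_action_log_probs (assumed semantics).
        # logits: [batch, seq_len, vocab]; input_ids: [batch, seq_len]
        batch_size = input_ids.size(0)
        out = torch.zeros((batch_size, num_action), device=input_ids.device)
        for i in range(batch_size):
            # log_softmax over one sample at a time: the peak activation is
            # 1 * seq_len * vocab * 4 bytes instead of batch * seq_len * vocab * 4.
            log_probs = F.log_softmax(logits[i : i + 1].float(), dim=-1)  # [1, seq_len, vocab]
            # Log-prob of each realized next token (positions shifted by one).
            per_token = torch.gather(
                log_probs[:, :-1], dim=-1, index=input_ids[i : i + 1, 1:].unsqueeze(-1)
            ).squeeze(-1)  # [1, seq_len - 1]
            # Keep only the trailing num_action generated (action) tokens.
            out[i, :] = per_token[0, -num_action:]
        return out

The loop costs batch_size smaller kernel launches instead of one large one, which is negligible next to the forward pass itself; the commit applies the same pattern in both the reference-model branch and _criterion above.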
