Commit 30a6859

optimize pp log_softmax OOM
1 parent 0e69b98

File tree

2 files changed: +28 -11 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -171,3 +171,5 @@ applications/ColossalChat/*.txt
 applications/ColossalChat/*.db
 applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
+applications/ColossalChat/*.prof
+applications/ColossalChat/*.png

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 26 additions & 11 deletions

@@ -280,13 +280,21 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             )

             if self.booster.plugin.stage_manager.is_last_stage():
-                reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
+                reference_action_log_probs = torch.zeros(
+                    (input_ids_forward_micro_batch.size(0), num_action),
+                    device=input_ids_forward_micro_batch.device,
                 )
+                for i in range(reference_action_log_probs.size(0)):
+                    # The log_softmax activation is too large when the vocab size and sequence length are large:
+                    # e.g., with a 152064 vocab size, a 32K sequence length, and a micro batch size of 4 (say, for pp=4),
+                    # this activation alone takes 152064 * 32000 * 4 * 4 / 1024**3 = 72.5 GB.
+                    reference_action_log_probs[i, :] += calc_action_log_probs(
+                        reference_model_outputs["outputs"]["logits"][i : i + 1]
+                        / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch[i : i + 1],
+                        num_action,
+                        self.plugin.shard_config,
+                    )[0]
             else:
                 # Dummy reference logprobs for data iterator.
                 reference_action_log_probs = None
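
For context on the figure quoted in the new comment: a full log_softmax materializes one float per logit, so its activation scales as vocab_size * seq_len * micro_batch * bytes_per_element, while the per-sample loop above caps the peak at a single sample's slice. A minimal back-of-the-envelope sketch, assuming fp32 activations (the 4-byte factor used in the comment):

# Sketch: reproduce the 72.5 GB figure from the diff comment.
vocab_size = 152064    # vocab size quoted in the comment
seq_len = 32000        # ~32K sequence length
micro_batch = 4        # micro batch size of 4 (e.g., pp=4)
bytes_per_elem = 4     # fp32 activations

full_batch_gb = vocab_size * seq_len * micro_batch * bytes_per_elem / 1024**3
per_sample_gb = vocab_size * seq_len * 1 * bytes_per_elem / 1024**3
print(f"log_softmax activation, whole micro batch: {full_batch_gb:.1f} GB")  # ~72.5 GB
print(f"log_softmax activation, one sample:        {per_sample_gb:.1f} GB")  # ~18.1 GB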

@@ -308,12 +316,19 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

             def _criterion(outputs, inputs):
                 action_logits = outputs.logits
-                action_log_probs = calc_action_log_probs(
-                    action_logits / self.generate_config["temperature"],
-                    inputs["input_ids"],
-                    num_action,
-                    self.plugin.shard_config,
+                action_log_probs = torch.zeros(
+                    (inputs["input_ids"].size(0), num_action), device=action_logits.device
                 )
+                for i in range(action_log_probs.size(0)):
+                    # The log_softmax activation is too large when the vocab size and sequence length are large:
+                    # e.g., with a 152064 vocab size, a 32K sequence length, and a micro batch size of 4 (say, for pp=4),
+                    # this activation alone takes 152064 * 32000 * 4 * 4 / 1024**3 = 72.5 GB.
+                    action_log_probs[i, :] += calc_action_log_probs(
+                        action_logits[i : i + 1] / self.generate_config["temperature"],
+                        inputs["input_ids"][i : i + 1],
+                        num_action,
+                        self.plugin.shard_config,
+                    )[0]
                 if "reference_action_log_probs" in inputs:
                     per_token_kl = (
                         torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
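
The same row-at-a-time pattern, shown in isolation: the sketch below is a simplified, hypothetical stand-in for calc_action_log_probs (the real helper also handles sharded vocabularies via shard_config, which is omitted here). It illustrates how taking log_softmax over a [1, seq_len, vocab] slice, rather than the whole [micro_batch, seq_len, vocab] tensor, divides the peak activation by the micro batch size:

import torch

def per_sample_action_log_probs(logits, input_ids, num_action):
    # Log-probs of the last `num_action` (generated) tokens, one sample at a time.
    out = torch.zeros((input_ids.size(0), num_action), device=logits.device)
    for i in range(input_ids.size(0)):
        # log_softmax over a [1, seq_len - 1, vocab] slice instead of the whole micro batch
        log_probs = torch.log_softmax(logits[i : i + 1, :-1].float(), dim=-1)
        # logits at position t score token t + 1, so shift the labels by one
        labels = input_ids[i : i + 1, 1:].unsqueeze(-1)
        token_log_probs = log_probs.gather(dim=-1, index=labels).squeeze(-1)
        out[i, :] = token_log_probs[0, -num_action:]  # keep only the action tokens
    return out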
