Commit e3d56cb

implement memory efficient logprob
1 parent 30a6859 commit e3d56cb

2 files changed: +43 -48 lines changed


applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 14 additions & 31 deletions
@@ -6,7 +6,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -280,21 +280,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
     )

     if self.booster.plugin.stage_manager.is_last_stage():
-        reference_action_log_probs = torch.zeros(
-            (input_ids_forward_micro_batch.size(0), num_action),
-            device=input_ids_forward_micro_batch.device,
+        reference_action_log_probs = memory_efficient_logprob(
+            reference_model_outputs["outputs"]["logits"],
+            input_ids_forward_micro_batch,
+            num_action,
+            shard_config=self.plugin.shard_config,
         )
-        for i in range(reference_action_log_probs.size(0)):
-            # activation for log_softmax is too large if vocab size and sequence length are large
-            # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-            # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-            reference_action_log_probs[i, :] += calc_action_log_probs(
-                reference_model_outputs["outputs"]["logits"][i : i + 1]
-                / self.generate_config["temperature"],
-                input_ids_forward_micro_batch[i : i + 1],
-                num_action,
-                self.plugin.shard_config,
-            )[0]
     else:
         # Dummy reference logprobs for data iterator.
         reference_action_log_probs = None
@@ -316,19 +307,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

 def _criterion(outputs, inputs):
     action_logits = outputs.logits
-    action_log_probs = torch.zeros(
-        (inputs["input_ids"].size(0), num_action), device=action_logits.device
+    action_log_probs = memory_efficient_logprob(
+        action_logits,
+        inputs["input_ids"],
+        num_action,
+        shard_config=self.plugin.shard_config,
     )
-    for i in range(action_log_probs.size(0)):
-        # activation for log_softmax is too large if vocab size and sequence length are large
-        # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-        # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-        action_log_probs[i, :] += calc_action_log_probs(
-            action_logits[i : i + 1] / self.generate_config["temperature"],
-            inputs["input_ids"][i : i + 1],
-            num_action,
-            self.plugin.shard_config,
-        )[0]
     if "reference_action_log_probs" in inputs:
         per_token_kl = (
             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -370,16 +354,15 @@ def _criterion(outputs, inputs):
         mean_kl.append(kl)
         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
     else:
-
         policy_model_logits = self.policy_model(
             input_ids=input_ids_forward_micro_batch,
             attention_mask=attention_mask_forward_micro_batch,
         ).logits
-        action_log_probs = calc_action_log_probs(
+        action_log_probs = memory_efficient_logprob(
             policy_model_logits / self.generate_config["temperature"],
             input_ids_forward_micro_batch,
             num_action,
-            self.plugin.shard_config,
+            shard_config=self.plugin.shard_config,
         )

         if self.policy_loss_fn.beta > 0:
@@ -388,7 +371,7 @@ def _criterion(outputs, inputs):
             input_ids=input_ids_forward_micro_batch,
             attention_mask=attention_mask_forward_micro_batch,
         ).logits
-        reference_action_log_probs = calc_action_log_probs(
+        reference_action_log_probs = memory_efficient_logprob(
             reference_model_logits / self.generate_config["temperature"],
             input_ids_forward_micro_batch,
             num_action,
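All of the removed per-sample loops above existed to avoid one peak allocation: the full log_softmax activation over the vocabulary. As a quick back-of-envelope check (a sketch using the figures quoted in the removed comment, not measurements, and assuming float32 activations; the 2048-token chunk matches the default chunk_size of the new helper below):

# Rough activation-size estimate; numbers come from the removed comment, not profiling.
vocab_size, seq_len, micro_bsz, bytes_per_elem = 152064, 32000, 4, 4

# one log_softmax over the whole micro batch and sequence at once
full_activation = vocab_size * seq_len * micro_bsz * bytes_per_elem / 1024**3
# one sample, one 2048-token chunk, as in memory_efficient_logprob
chunk_activation = vocab_size * 2048 * 1 * bytes_per_elem / 1024**3

print(f"full:      {full_activation:.1f} GiB")   # ~72.5 GiB
print(f"per chunk: {chunk_activation:.2f} GiB")  # ~1.16 GiB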

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 29 additions & 17 deletions
@@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return per_label_logps.squeeze(-1)


-def calc_action_log_probs(
+def memory_efficient_logprob(
     logits: torch.Tensor,
-    sequences: torch.LongTensor,
-    num_actions: int,
-    shard_config,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
     vocab_size: int = None,
 ) -> torch.Tensor:
-    """Calculate action log probs.
-
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
     Args:
         logits (torch.Tensor): Output tensor of Actor.forward.logits.
-        sequences (torch.LongTensor): Input sequences.
-        num_actions (int): Number of actions.
-        shard_config
-        vocab_size
-
-
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
     Returns:
         torch.Tensor: Action log probs.
     """
-    # labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
-    # logits: torch.Tensor,  # [B, S, Vocab_size]
-    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
-    log_probs = log_probs.squeeze(-1)
-    return log_probs[:, -num_actions:]
+    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+    context_length = logits.size(1) - num_action
+    for i in range(action_log_probs.size(0)):
+        # loop over each sample in the micro-batch
+        for start in range(context_length, logits.size(1), chunk_size):
+            end = min(start + chunk_size, logits.size(1))
+            # calculate log probs in chunks to save memory
+            log_probs = dist_log_prob(
+                inputs[i : i + 1, start - 1 : end],
+                logits[i : i + 1, start - 1 : end],
+                shard_config,
+                vocab_size,
+                logits.dtype,
+            )  # [1, chunk_size, 1]
+            log_probs = log_probs.squeeze(-1)
+            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+    return action_log_probs


 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
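For readers outside the Colossal-AI sharding stack, here is a minimal standalone sketch of the same chunking idea. Assumptions: a plain, non-tensor-parallel setup, so torch.log_softmax plus gather stands in for the repo's dist_log_prob (which is assumed to apply the usual one-token shift internally); it chunks only along the sequence dimension, whereas the commit additionally loops over samples one at a time. The names here are illustrative, not part of the commit.

import torch


def chunked_action_logprob(
    logits: torch.Tensor,     # [B, S, V] token logits
    input_ids: torch.Tensor,  # [B, S] token ids
    num_action: int,          # number of trailing response ("action") tokens
    chunk_size: int = 2048,
) -> torch.Tensor:
    """Per-token log p(x_t | x_<t) for the last num_action tokens, materializing
    the [B, chunk, V] log_softmax activation one chunk at a time."""
    batch_size, seq_len, _ = logits.shape
    context_length = seq_len - num_action
    out = torch.zeros(batch_size, num_action, device=logits.device, dtype=logits.dtype)
    for start in range(context_length, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        chunk_logits = logits[:, start - 1 : end - 1]  # logits at position t-1 predict token t
        chunk_labels = input_ids[:, start:end]
        log_probs = torch.log_softmax(chunk_logits.float(), dim=-1)  # [B, end-start, V]
        out[:, start - context_length : end - context_length] = (
            log_probs.gather(-1, chunk_labels.unsqueeze(-1)).squeeze(-1)
        )
    return out


# Example: 2 samples, 512-token sequences, 128 response tokens, tiny vocab.
logits = torch.randn(2, 512, 4096)
input_ids = torch.randint(0, 4096, (2, 512))
print(chunked_action_logprob(logits, input_ids, num_action=128, chunk_size=64).shape)  # torch.Size([2, 128])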
