Skip to content

Commit 245c8c2

Browse files
committed
implement memory efficient logprob
1 parent a960990 commit 245c8c2

File tree

2 files changed

+43
-48
lines changed

2 files changed

+43
-48
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import wandb
77
from coati.distributed.consumer import BaseConsumer
88
from coati.distributed.loss import PolicyLoss
9-
from coati.distributed.utils import calc_action_log_probs
9+
from coati.distributed.utils import memory_efficient_logprob
1010
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
1111
from transformers import AutoModelForCausalLM, AutoTokenizer
1212

@@ -293,21 +293,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
293293
)
294294

295295
if self.booster.plugin.stage_manager.is_last_stage():
296-
reference_action_log_probs = torch.zeros(
297-
(input_ids_forward_micro_batch.size(0), num_action),
298-
device=input_ids_forward_micro_batch.device,
296+
reference_action_log_probs = memory_efficient_logprob(
297+
reference_model_outputs["outputs"]["logits"],
298+
input_ids_forward_micro_batch,
299+
num_action,
300+
shard_config=self.plugin.shard_config,
299301
)
300-
for i in range(reference_action_log_probs.size(0)):
301-
# activation for log_softmax is too large if vocab size and sequence length are large
302-
# e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
303-
# this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
304-
reference_action_log_probs[i, :] += calc_action_log_probs(
305-
reference_model_outputs["outputs"]["logits"][i : i + 1]
306-
/ self.generate_config["temperature"],
307-
input_ids_forward_micro_batch[i : i + 1],
308-
num_action,
309-
self.plugin.shard_config,
310-
)[0]
311302
else:
312303
# Dummy reference logprobs for data iterator.
313304
reference_action_log_probs = None
@@ -329,19 +320,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
329320

330321
def _criterion(outputs, inputs):
331322
action_logits = outputs.logits
332-
action_log_probs = torch.zeros(
333-
(inputs["input_ids"].size(0), num_action), device=action_logits.device
323+
action_log_probs = memory_efficient_logprob(
324+
action_logits,
325+
inputs["input_ids"],
326+
num_action,
327+
shard_config=self.plugin.shard_config,
334328
)
335-
for i in range(action_log_probs.size(0)):
336-
# activation for log_softmax is too large if vocab size and sequence length are large
337-
# e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
338-
# this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
339-
action_log_probs[i, :] += calc_action_log_probs(
340-
action_logits[i : i + 1] / self.generate_config["temperature"],
341-
inputs["input_ids"][i : i + 1],
342-
num_action,
343-
self.plugin.shard_config,
344-
)[0]
345329
if "reference_action_log_probs" in inputs:
346330
per_token_kl = (
347331
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -383,16 +367,15 @@ def _criterion(outputs, inputs):
383367
mean_kl.append(kl)
384368
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
385369
else:
386-
387370
policy_model_logits = self.policy_model(
388371
input_ids=input_ids_forward_micro_batch,
389372
attention_mask=attention_mask_forward_micro_batch,
390373
).logits
391-
action_log_probs = calc_action_log_probs(
374+
action_log_probs = memory_efficient_logprob(
392375
policy_model_logits / self.generate_config["temperature"],
393376
input_ids_forward_micro_batch,
394377
num_action,
395-
self.plugin.shard_config,
378+
shard_config=self.plugin.shard_config,
396379
)
397380

398381
if self.policy_loss_fn.beta > 0:
@@ -401,7 +384,7 @@ def _criterion(outputs, inputs):
401384
input_ids=input_ids_forward_micro_batch,
402385
attention_mask=attention_mask_forward_micro_batch,
403386
).logits
404-
reference_action_log_probs = calc_action_log_probs(
387+
reference_action_log_probs = memory_efficient_logprob(
405388
reference_model_logits / self.generate_config["temperature"],
406389
input_ids_forward_micro_batch,
407390
num_action,

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,31 +71,43 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
7171
return per_label_logps.squeeze(-1)
7272

7373

74-
def calc_action_log_probs(
74+
def memory_efficient_logprob(
7575
logits: torch.Tensor,
76-
sequences: torch.LongTensor,
77-
num_actions: int,
78-
shard_config,
76+
inputs: torch.Tensor,
77+
num_action: int,
78+
chunk_size: int = 2048,
79+
shard_config: Any = None,
7980
vocab_size: int = None,
8081
) -> torch.Tensor:
81-
"""Calculate action log probs.
82-
82+
"""
83+
Calculate action log probs in a memory-efficient way by processing in chunks.
8384
Args:
8485
logits (torch.Tensor): Output tensor of Actor.forward.logits.
85-
sequences (torch.LongTensor): Input sequences.
86-
num_actions (int): Number of actions.
87-
shard_config
88-
vocab_size
89-
90-
86+
inputs (torch.LongTensor): Input sequences.
87+
num_action (int): Number of actions.
88+
chunk_size (int, optional): Size of each chunk to process. Default is 2048.
89+
shard_config: Shard configuration for distributed computation.
90+
vocab_size (int, optional): Vocabulary size. Default is None.
9191
Returns:
9292
torch.Tensor: Action log probs.
9393
"""
94-
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
95-
# logits: torch.Tensor, # [B, S, Vocab_size]
96-
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
97-
log_probs = log_probs.squeeze(-1)
98-
return log_probs[:, -num_actions:]
94+
action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
95+
context_length = logits.size(1) - num_action
96+
for i in range(action_log_probs.size(0)):
97+
# loop over each sample in the micro-batch
98+
for start in range(context_length, logits.size(1), chunk_size):
99+
end = min(start + chunk_size, logits.size(1))
100+
# calculate log probs in chunks to save memory
101+
log_probs = dist_log_prob(
102+
inputs[i : i + 1, start - 1 : end],
103+
logits[i : i + 1, start - 1 : end],
104+
shard_config,
105+
vocab_size,
106+
logits.dtype,
107+
) # [1, chunk_size, 1]
108+
log_probs = log_probs.squeeze(-1)
109+
action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
110+
return action_log_probs
99111

100112

101113
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:

0 commit comments

Comments
 (0)