Commit 2b87def

fix: OOM in deepscaler1.5b with sequence length = 16/24k (#875)
Signed-off-by: Qidong Su <[email protected]>
1 parent fecf71e

2 files changed (+10, -10 lines)

examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml

Lines changed: 0 additions & 4 deletions

@@ -45,7 +45,3 @@ policy:
       gpu_memory_utilization: 0.8
       enforce_eager: True
       max_model_len: ${policy.max_total_sequence_length}
-
-cluster:
-  gpus_per_node: 8
-  num_nodes: 4

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 10 additions & 6 deletions

@@ -705,6 +705,7 @@ def train(
                     logits = self.model.lm_head(outputs.last_hidden_state)
                 else:
                     logits = outputs.logits
+                del outputs

                 # Apply temperature scaling
                 logits = self._apply_temperature_scaling(logits)
@@ -771,6 +772,7 @@ def train(
                         global_valid_seqs,
                         global_valid_toks,
                     )
+                    del logits

                     # skip the update for dummy batches
                     if mb_idx < iterator_len:
@@ -1029,17 +1031,19 @@ def get_logprobs(
                     placements=[Shard(sequence_dim), Shard(-1)],
                 )

+                logits = logits.to(torch.float32)
                 token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                    logits.to(torch.float32),
+                    logits,
                     input_ids_dtensor,
                     seq_index_tensor,
                 )

                 assert token_logprobs.shape[1] == seq_len - 1
             else:
                 if isinstance(logits, DTensor):
+                    logits = logits.to(torch.float32)
                     token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                        logits.to(torch.float32), input_ids
+                        logits, input_ids
                     )
                 else:
                     # Extract logprobs for each token in the sequence by gathering the logprob
@@ -1049,16 +1053,16 @@ def get_logprobs(
                     # token_ids: [batch_size, sequence_length] - actual tokens
                     # Output shape: [batch_size, sequence_length] - logprob of each token given previous
                     # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length
-
-                    log_probs = torch.nn.functional.log_softmax(
-                        outputs.logits.to(torch.float32), dim=-1
-                    )
+                    logits = outputs.logits.to(torch.float32)
+                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                     next_tokens = input_ids[:, 1:]
                     log_probs = log_probs[:, :-1]
                     token_logprobs = log_probs.gather(
                         dim=-1, index=next_tokens.unsqueeze(-1)
                     ).squeeze(-1)

+            del outputs, logits
+
             token_logprobs = torch.cat(
                 [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
             )
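The whole fix follows one pattern: drop references to the large intermediates (the model `outputs` and the pre-cast `logits`) as soon as the values derived from them exist, so the caching allocator can reuse that memory before the next [batch, seq, vocab]-sized allocation; at 16K-24K tokens those tensors dominate peak memory. Below is a minimal sketch of the same pattern outside the repository, assuming a Hugging Face-style causal LM that exposes `.logits`; the function name, model interface, and shapes are illustrative and are not NeMo RL's actual API.

import torch

def token_logprobs_memory_lean(model, input_ids):
    """Per-token logprobs, releasing large intermediates early (illustrative sketch)."""
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        # Cast once and rebind; deleting `outputs` then releases the original
        # (possibly bf16) logits instead of keeping both precisions alive.
        logits = outputs.logits.to(torch.float32)
        del outputs

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        del logits  # the fp32 [batch, seq, vocab] tensor is the peak-memory item

        # logprob of token[t+1] comes from position t; prepend 0 to keep length.
        next_tokens = input_ids[:, 1:]
        token_logprobs = log_probs[:, :-1].gather(
            dim=-1, index=next_tokens.unsqueeze(-1)
        ).squeeze(-1)
        del log_probs

        return torch.cat(
            [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
        )

As in the diff, rebinding the fp32 copy to `logits` before deleting its source means only one precision of the full vocabulary-sized tensor stays resident for the rest of the computation.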
