@@ -705,6 +705,7 @@ def train(
                 logits = self.model.lm_head(outputs.last_hidden_state)
             else:
                 logits = outputs.logits
+            del outputs
 
             # Apply temperature scaling
             logits = self._apply_temperature_scaling(logits)
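Aside: `_apply_temperature_scaling` is the worker's own helper, so the snippet below is only a hypothetical stand-in showing what temperature scaling conventionally does (divide the logits by a positive temperature); the actual method in this repo may special-case a temperature of 1.0 or read the value from its config.

```python
import torch

def apply_temperature_scaling(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in: dividing logits by a positive temperature flattens
    # (T > 1) or sharpens (T < 1) the resulting softmax distribution.
    if temperature == 1.0:
        return logits
    return logits / temperature

logits = torch.randn(2, 4, 8)                      # [batch, seq_len, vocab]
scaled = apply_temperature_scaling(logits, 0.7)
print(torch.allclose(scaled, logits / 0.7))        # True
```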
@@ -771,6 +772,7 @@ def train(
                 global_valid_seqs,
                 global_valid_toks,
             )
+            del logits
 
             # skip the update for dummy batches
             if mb_idx < iterator_len:
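The added `del logits` here (like `del outputs` in the previous hunk) drops the last reference to the microbatch's largest intermediate tensor before the next allocation, so its storage returns to PyTorch's caching allocator and peak memory over the microbatch loop goes down. A minimal illustration, with a made-up tensor size and assuming a CUDA device is available:

```python
import torch

# Once the last reference to a large activation tensor is deleted, its memory
# is returned to the caching allocator and can be reused by later allocations.
if torch.cuda.is_available():
    logits = torch.randn(1, 2048, 32000, device="cuda", dtype=torch.float32)  # ~262 MB
    print(f"allocated before del: {torch.cuda.memory_allocated() / 1e6:.0f} MB")
    del logits
    print(f"allocated after del:  {torch.cuda.memory_allocated() / 1e6:.0f} MB")
```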
@@ -1029,17 +1031,19 @@ def get_logprobs(
                     placements=[Shard(sequence_dim), Shard(-1)],
                 )
 
+                logits = logits.to(torch.float32)
                 token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                    logits.to(torch.float32),
+                    logits,
                     input_ids_dtensor,
                     seq_index_tensor,
                 )
 
                 assert token_logprobs.shape[1] == seq_len - 1
             else:
                 if isinstance(logits, DTensor):
+                    logits = logits.to(torch.float32)
                     token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                        logits.to(torch.float32), input_ids
+                        logits, input_ids
                     )
                 else:
                     # Extract logprobs for each token in the sequence by gathering the logprob
@@ -1049,16 +1053,16 @@ def get_logprobs(
                     # token_ids: [batch_size, sequence_length] - actual tokens
                     # Output shape: [batch_size, sequence_length] - logprob of each token given previous
                     # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length
-
-                    log_probs = torch.nn.functional.log_softmax(
-                        outputs.logits.to(torch.float32), dim=-1
-                    )
+                    logits = outputs.logits.to(torch.float32)
+                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                     next_tokens = input_ids[:, 1:]
                     log_probs = log_probs[:, :-1]
                     token_logprobs = log_probs.gather(
                         dim=-1, index=next_tokens.unsqueeze(-1)
                     ).squeeze(-1)
 
+            del outputs, logits
+
             token_logprobs = torch.cat(
                 [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
             )
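For reference, the non-DTensor branch above can be read as a standalone function. This sketch (with made-up shapes) mirrors the same log-softmax, shift-by-one gather, and zero-prepend, since logits at position t score the token at position t + 1 and the output must keep the full sequence length:

```python
import torch

def token_logprobs_from_logits(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]; input_ids: [batch, seq_len].
    # logits[:, t] predicts input_ids[:, t + 1], so gather with a one-token shift.
    log_probs = torch.nn.functional.log_softmax(logits.to(torch.float32), dim=-1)
    next_tokens = input_ids[:, 1:]
    log_probs = log_probs[:, :-1]
    token_logprobs = log_probs.gather(dim=-1, index=next_tokens.unsqueeze(-1)).squeeze(-1)
    # Prepend a zero for position 0 (no preceding context) to keep the sequence length.
    return torch.cat([torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1)

logits = torch.randn(2, 5, 11)                    # hypothetical [batch, seq_len, vocab]
input_ids = torch.randint(0, 11, (2, 5))
print(token_logprobs_from_logits(logits, input_ids).shape)  # torch.Size([2, 5])
```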