 from typing import Any

 import torch
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from transformers import AutoModelForCausalLM

+from forge.controller import ForgeActor
+

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -93,7 +93,7 @@ async def setup(self):
     async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
         """
         Given a request and response tokens, return the log_probability of the
-        token_ids
+        token_ids, shape (completion_len,)

         """
         model_parts = self.engine.model_parts
@@ -129,9 +129,10 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor

             # Compute logprobs
             input_ids = input_ids[:, len(request) :]
+            # (bsz=1, completion_len)
             logprobs = compute_logprobs(logits, input_ids)
-
-            return logprobs
+            # (completion_len,)
+            return logprobs.squeeze(0)

         return pred
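With the batch dimension squeezed away, callers of the endpoint receive a flat per-token tensor. A minimal sketch of how such a 1-D result can be consumed (the `logprobs` values below are illustrative stand-ins for what `forward` returns, not taken from this change):

```python
import torch

# Illustrative stand-in for the (completion_len,) tensor returned by forward().
logprobs = torch.tensor([-0.7, -1.2, -0.3])

sequence_logprob = logprobs.sum()   # log P(response | request) under the reference model
mean_logprob = logprobs.mean()      # length-normalized variant, handy for comparisons
assert logprobs.dim() == 1          # batch dimension already squeezed away
```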
@@ -140,14 +141,39 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 def compute_logprobs(
     logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
 ) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
+    """
+    Compute the log probabilities of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable-length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids of the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.

-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )

-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)

     return logprobs

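To make the `context_len - 1 : -1` slice concrete: the logit at sequence position `i` scores the token at position `i + 1`, so completion token `j` is read from logits position `context_len - 1 + j`. A small self-contained sketch (sizes and tensors are illustrative, not from this change) checking the vectorized slice/gather against a naive per-token loop:

```python
import torch

# Illustrative sizes: 5 prompt tokens, 3 completion tokens, a 13-token vocabulary.
vocab_size, context_len, completion_len = 13, 5, 3
logits = torch.randn(1, context_len + completion_len, vocab_size)
completion_ids = torch.randint(vocab_size, (1, completion_len))

# Vectorized path, mirroring compute_logprobs above.
shifted = logits[:, context_len - 1 : -1, :]            # (1, completion_len, vocab_size)
vectorized = torch.gather(
    torch.log_softmax(shifted, dim=-1), 2, completion_ids.unsqueeze(-1)
).squeeze(-1)                                            # (1, completion_len)

# Naive reference: token j is predicted by the logits at position context_len - 1 + j.
naive = torch.stack(
    [
        torch.log_softmax(logits[0, context_len - 1 + j], dim=-1)[completion_ids[0, j]]
        for j in range(completion_len)
    ]
).unsqueeze(0)

torch.testing.assert_close(vectorized, naive)
```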