
Commit 2bb7bf5

messed up history...
1 parent e930b3b commit 2bb7bf5

File tree: 1 file changed (+37, -11 lines)


src/forge/actors/reference_actor.py

Lines changed: 37 additions & 11 deletions
@@ -17,8 +17,6 @@
 from typing import Any
 
 import torch
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
@@ -30,6 +28,8 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from transformers import AutoModelForCausalLM
 
+from forge.controller import ForgeActor
+
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -93,7 +93,7 @@ async def setup(self):
     async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
         """
         Given a request and response tokens, return the log_probability of the
-        token_ids
+        token_ids, shape (completion_len, )
 
         """
         model_parts = self.engine.model_parts
@@ -129,9 +129,10 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 
             # Compute logprobs
             input_ids = input_ids[:, len(request) :]
+            # (bsz=1, completion_len)
             logprobs = compute_logprobs(logits, input_ids)
-
-            return logprobs
+            # (completion_len, )
+            return logprobs.squeeze(0)
 
         return pred
 
@@ -140,14 +141,39 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 def compute_logprobs(
     logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
 ) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
+    """
+    Compute log probs of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.
 
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )
 
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)
 
     return logprobs
 
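For readers following the change, below is a minimal standalone sketch (not part of the commit) of the shape contract that the reworked compute_logprobs and the new squeeze(0) in forward establish. The request/response token ids and the random logits are invented for illustration; only the slicing, log_softmax, and gather steps mirror the patched code.

import torch

# Invented example sizes: 3 prompt tokens, 2 completion tokens, tiny vocab.
request = [1, 2, 3]
response = [4, 5]
vocab_size = 8

# Stand-in for model output over the full sequence: (bsz=1, seq_len, vocab_size)
logits = torch.randn(1, len(request) + len(response), vocab_size)

# Completion token ids: (bsz=1, completion_len)
input_ids = torch.tensor([response])

# Same shift as the patch: the logit at position t predicts token t + 1,
# so positions context_len - 1 .. seq_len - 2 score the completion tokens.
context_len = logits.shape[1] - input_ids.shape[1]
shifted = logits[:, context_len - 1 : -1, :]  # (1, completion_len, vocab_size)

token_logprobs = torch.log_softmax(shifted, dim=-1)
logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)

print(logprobs.shape)             # torch.Size([1, 2])  -> (bsz, completion_len)
print(logprobs.squeeze(0).shape)  # torch.Size([2])     -> what forward() now returns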