Commit 0e19841

Update reference_actor.py (#125)
* Update reference_actor.py
* add tests, clean up
* messed up history...
1 parent 75b5fd6 commit 0e19841

2 files changed: +136 -12 lines changed
src/forge/actors/reference_actor.py

Lines changed: 38 additions & 12 deletions
@@ -17,8 +17,6 @@
 from typing import Any

 import torch
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
@@ -30,6 +28,8 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from transformers import AutoModelForCausalLM

+from forge.controller import ForgeActor
+

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -93,7 +93,7 @@ async def setup(self):
     async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
         """
         Given a request and response tokens, return the log_probability of the
-        token_ids
+        token_ids, shape (completion_len, )

         """
         model_parts = self.engine.model_parts
@@ -128,10 +128,11 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
             logits = model_parts[0](input_ids)

             # Compute logprobs
-            input_ids = input_ids[:, len(response) :]
+            input_ids = input_ids[:, len(request) :]
+            # (bsz=1, completion_len)
             logprobs = compute_logprobs(logits, input_ids)
-
-            return logprobs
+            # (completion_len, )
+            return logprobs.squeeze(0)

         return pred
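
The slicing fix above is easiest to see on a toy input. A minimal sketch (not part of the commit; token values are illustrative, and it assumes forward scores the concatenation request + response):

    import torch

    request = [101, 102, 103]  # 3 prompt token ids
    response = [201, 202]      # 2 completion token ids
    input_ids = torch.tensor([request + response])  # shape (1, 5)

    # The new slice keeps exactly the completion ids:
    completion_ids = input_ids[:, len(request):]    # tensor([[201, 202]])
    # The old slice, input_ids[:, len(response):], kept tensor([[103, 201, 202]])
    # whenever len(request) != len(response), so the wrong tokens were scored.

The trailing .squeeze(0) then drops the bsz=1 dimension, matching the (completion_len, ) shape promised by the updated docstring.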

@@ -140,14 +141,39 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
 def compute_logprobs(
     logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
 ) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
+    """
+    Compute log probs of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.

-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )

-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)

     return logprobs
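
For reference, a minimal usage sketch of the rewritten helper (random tensors, purely illustrative; shapes follow the docstring above):

    import torch

    from forge.actors.reference_actor import compute_logprobs

    prompt_len, completion_len, vocab_size = 3, 2, 10
    seq_len = prompt_len + completion_len

    logits = torch.randn(1, seq_len, vocab_size)  # logits over the full sequence
    completion_ids = torch.randint(0, vocab_size, (1, completion_len))

    # Internally: context_len = 5 - 2 = 3, so logits[:, 2:4, :] is kept,
    # because the logits at position i score the token at position i + 1.
    logprobs = compute_logprobs(logits, completion_ids, temperature=1.0)
    assert logprobs.shape == (1, completion_len)
    assert (logprobs <= 0).all()  # log probabilities are never positive

A higher temperature flattens the softmax before the gather, so the returned values change; the ValueError path fires only when the completion is at least as long as the whole sequence, i.e. context_len <= 0.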

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Tests for reference_actor.py - compute_logprobs function
+"""
+
+import pytest
+import torch
+
+from forge.actors.reference_actor import compute_logprobs
+
+
+class TestComputeLogprobs:
+    """Test the compute_logprobs utility function."""
+
+    def test_compute_logprobs_basic(self):
+        """Test basic logprobs computation."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 3
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+
+        # Create mock input_ids for response tokens
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        # Verify output shape and properties
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (batch_size, response_len)
+        assert torch.all(result <= 0)  # Log probabilities should be <= 0
+
+    def test_compute_logprobs_with_temperature(self):
+        """Test logprobs computation with temperature scaling."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 3
+        temperature = 0.1
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids, temperature)
+
+        assert isinstance(result, torch.Tensor)
+        assert result.shape == (batch_size, response_len)
+        assert torch.all(result <= 0)
+        default_result = compute_logprobs(logits, input_ids)
+        assert not torch.allclose(result, default_result)
+
+    def test_compute_logprobs_single_token(self):
+        """Test logprobs computation with single token response."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 1
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        assert result.shape == (batch_size, response_len)
+        assert result.numel() == 1  # Single element
+
+    def test_compute_logprobs_empty_response(self):
+        """Test logprobs computation with empty response."""
+        batch_size = 1
+        seq_len = 5
+        vocab_size = 1000
+        response_len = 0
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+
+        result = compute_logprobs(logits, input_ids)
+
+        assert result.shape == (batch_size, response_len)
+
+    def test_compute_logprobs_empty_prompt(self):
+        """Test logprobs computation with empty prompt."""
+        batch_size = 1
+        vocab_size = 1000
+        prompt_len = 0
+        response_len = 5
+        seq_len = prompt_len + response_len
+
+        logits = torch.randn(batch_size, seq_len, vocab_size)
+        input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
+        with pytest.raises(ValueError, match=r"(?i).*context length.*"):
+            _ = compute_logprobs(logits, input_ids)
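
Assuming the new test file lands somewhere pytest collects it (its path is not shown in this view), the class can be run with, e.g.:

    pytest -k TestComputeLogprobs -v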
