50 changes: 38 additions & 12 deletions src/forge/actors/reference_actor.py
@@ -17,8 +17,6 @@
from typing import Any

import torch

from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn
@@ -30,6 +28,8 @@
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from transformers import AutoModelForCausalLM

from forge.controller import ForgeActor


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -93,7 +93,7 @@ async def setup(self):
async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
"""
Given request and response tokens, return the log probabilities of the
token_ids
token_ids, shape (completion_len, )
"""
model_parts = self.engine.model_parts
@@ -128,10 +128,11 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
logits = model_parts[0](input_ids)

# Compute logprobs
input_ids = input_ids[:, len(response) :]
input_ids = input_ids[:, len(request) :]
# (bsz=1, completion_len)
logprobs = compute_logprobs(logits, input_ids)

return logprobs
# (completion_len, )
return logprobs.squeeze(0)

return pred
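The fix above slices the full sequence by the request length rather than the response length, so only the completion tokens are handed to compute_logprobs. A minimal standalone sketch of that layout (hypothetical token values, not the actor code):

import torch

request = [11, 12, 13]  # prompt token ids (hypothetical values)
response = [21, 22]  # completion token ids (hypothetical values)
input_ids = torch.tensor([request + response])  # (1, 5): full sequence fed to the model

# Dropping the first len(request) positions keeps only the completion tokens,
# which is what compute_logprobs expects as its input_ids argument.
completion_ids = input_ids[:, len(request):]
assert completion_ids.tolist() == [response]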

@@ -140,14 +141,39 @@ async def forward(self, request: list[int], response: list[int]) -> torch.Tensor
def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]
# Truncate request logits and drop last
logits = logits[:, context_length - 1 : -1]
"""
Compute log probs of the completion input_ids given the logits of the whole sequence.
Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
Args:
logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
Returns:
torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.
Raises:
ValueError: If the inferred context length is less than or equal to 0.
"""
context_len = logits.shape[1] - input_ids.shape[1]
completion_len = input_ids.shape[1]
if context_len <= 0:
raise ValueError(
"Context length must be greater than 0. Otherwise the probability of the first token is undefined."
)

# Compute logprobs
logprobs = torch.log_softmax(logits / temperature, dim=-1)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
# (bsz, completion_len, vocab_size)
logits = logits[:, context_len - 1 : -1, :]
assert logits.shape == (
input_ids.shape[0],
completion_len,
logits.shape[-1],
), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
# (bsz, completion_len, 1)
logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
# (bsz, completion_len)
logprobs = logprobs.squeeze(-1)

return logprobs

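For orientation, a hedged usage sketch of compute_logprobs as defined above; the shapes follow the docstring and the tensors are random placeholders:

import torch

from forge.actors.reference_actor import compute_logprobs

bsz, prompt_len, completion_len, vocab = 2, 4, 3, 32
logits = torch.randn(bsz, prompt_len + completion_len, vocab)  # full-sequence logits
completion_ids = torch.randint(0, vocab, (bsz, completion_len))  # completion token ids only

logprobs = compute_logprobs(logits, completion_ids)  # (bsz, completion_len)
scaled = compute_logprobs(logits, completion_ids, temperature=0.7)  # temperature-scaled variant
assert logprobs.shape == (bsz, completion_len)

Passing the full-sequence ids instead of only the completion ids would shrink the inferred context length and misalign the gather, which is exactly what the forward() fix above avoids.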
98 changes: 98 additions & 0 deletions tests/unit_tests/actors/test_reference_actor.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests for reference_actor.py - compute_logprobs function
"""

import pytest
import torch

from forge.actors.reference_actor import compute_logprobs


class TestComputeLogprobs:
"""Test the compute_logprobs utility function."""

def test_compute_logprobs_basic(self):
"""Test basic logprobs computation."""
batch_size = 1
seq_len = 5
vocab_size = 1000
response_len = 3

logits = torch.randn(batch_size, seq_len, vocab_size)

# Create mock input_ids for response tokens
input_ids = torch.randint(0, vocab_size, (batch_size, response_len))

result = compute_logprobs(logits, input_ids)

# Verify output shape and properties
assert isinstance(result, torch.Tensor)
assert result.shape == (batch_size, response_len)
assert torch.all(result <= 0) # Log probabilities should be <= 0

def test_compute_logprobs_with_temperature(self):
"""Test logprobs computation with temperature scaling."""
batch_size = 1
seq_len = 5
vocab_size = 1000
response_len = 3
temperature = 0.1

logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, response_len))

result = compute_logprobs(logits, input_ids, temperature)

assert isinstance(result, torch.Tensor)
assert result.shape == (batch_size, response_len)
assert torch.all(result <= 0)
default_result = compute_logprobs(logits, input_ids)
assert not torch.allclose(result, default_result)

def test_compute_logprobs_single_token(self):
"""Test logprobs computation with single token response."""
batch_size = 1
seq_len = 5
vocab_size = 1000
response_len = 1

logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, response_len))

result = compute_logprobs(logits, input_ids)

assert result.shape == (batch_size, response_len)
assert result.numel() == 1 # Single element

def test_compute_logprobs_empty_response(self):
"""Test logprobs computation with empty response."""
batch_size = 1
seq_len = 5
vocab_size = 1000
response_len = 0

logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, response_len))

result = compute_logprobs(logits, input_ids)

assert result.shape == (batch_size, response_len)

def test_compute_logprobs_empty_prompt(self):
"""Test logprobs computation with empty prompt."""
batch_size = 1
vocab_size = 1000
prompt_len = 0
response_len = 5
seq_len = prompt_len + response_len

logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, response_len))
with pytest.raises(ValueError, match=r"(?i).*context length.*"):
_ = compute_logprobs(logits, input_ids)
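One extra check that could complement these tests (a sketch, not part of this PR): compare compute_logprobs against a per-token loop in which the logit at position prompt_len - 1 + i scores completion token i.

import torch

from forge.actors.reference_actor import compute_logprobs

bsz, prompt_len, completion_len, vocab = 1, 4, 3, 50
logits = torch.randn(bsz, prompt_len + completion_len, vocab)
completion_ids = torch.randint(0, vocab, (bsz, completion_len))

expected = torch.empty(bsz, completion_len)
for i in range(completion_len):
    # The logit at absolute position prompt_len - 1 + i predicts completion token i.
    step = torch.log_softmax(logits[:, prompt_len - 1 + i, :], dim=-1)
    expected[:, i] = step.gather(1, completion_ids[:, i : i + 1]).squeeze(1)

assert torch.allclose(compute_logprobs(logits, completion_ids), expected, atol=1e-6)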