# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
import torch.nn.functional as F

from forge.util.ops import compute_logprobs


def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
    """Reference helper: textbook log-softmax, then gather the log-prob of each target id."""
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

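
# A minimal sketch of the contract the tests below assume for compute_logprobs:
# the logit row at position t scores the target id at position t, so the logits
# are trimmed to the number of target ids (equivalently, the trailing logit row
# is dropped when there is one more logit position than target), divided by
# `temperature`, and the log-probability of each target id is gathered. This
# helper and its trimming rule are assumptions for illustration only, not the
# actual forge.util.ops implementation.
def _reference_compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
    trimmed = logits[:, : input_ids.shape[1]] / temperature
    return _textbook_log_softmax(trimmed, input_ids)
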

class TestComputeLogprobs:
    def test_single_batch_item(self):
        """Test with a single batch item, for one and two target tokens."""
        # Shape: (1, 2, 3)
        logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
        # Shape: (1, 1)
        input_ids = torch.tensor([[1]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation: only the first logit row scores the single target id.
        expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (1, 1)

        # Second case: a longer sequence with two target ids.

        # Shape: (1, 3, 3)
        logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
        # Shape: (1, 2)
        input_ids = torch.tensor([[2, 0]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation: the first two logit rows score the two target ids.
        expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (1, 2)

    @pytest.mark.timeout(10)
    def test_multi_batch(self):
        """Test with multiple batch items."""
        # Shape: (2, 2, 3)
        logits = torch.tensor(
            [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]]
        )
        # Shape: (2, 1)
        input_ids = torch.tensor([[1], [2]])
        result = compute_logprobs(logits, input_ids)

        # Manual calculation
        expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]])
        expected = _textbook_log_softmax(expected_logits, input_ids)

        assert torch.allclose(result, expected, atol=1e-5)
        assert result.shape == (2, 1)

    @pytest.mark.timeout(10)
    def test_temperature(self):
        """Test with different temperature values."""
        batch_size, seq_len, vocab_size = 2, 4, 6
        logits = torch.randn(batch_size, seq_len, vocab_size)
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1))

        # Manual calculation with temperature scaling
        def _manual(temperature: float):
            expected_logits = logits[:, 0:-1] / temperature
            return _textbook_log_softmax(expected_logits, input_ids)

        temperatures = [1.0, 2.0, 4.5]
        for temperature in temperatures:
            result = compute_logprobs(logits, input_ids, temperature=temperature)
            expected = _manual(temperature)
            assert torch.allclose(result, expected, atol=1e-5)
            assert result.shape == input_ids.shape

    @pytest.mark.timeout(10)
    def test_edge_cases(self):
        """Test edge cases."""
        # Test with very large values (numerical stability)
        logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]])
        input_ids = torch.tensor([[0]])
        result = compute_logprobs(logits, input_ids)
        # Should not be NaN or inf
        assert torch.isfinite(result).all()

        # Test with very small values
        logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]])
        input_ids = torch.tensor([[1]])
        result = compute_logprobs(logits, input_ids)
        # Should not be NaN or inf
        assert torch.isfinite(result).all()

    def test_compute_logprobs_empty_response(self):
        """Test logprobs computation with an empty response."""
        batch_size, seq_len, vocab_size = 1, 5, 1000
        logits = torch.randn(batch_size, seq_len, vocab_size)
        # Empty target ids; use long dtype explicitly, as real token ids would be.
        input_ids = torch.tensor([[]], dtype=torch.long)

        result = compute_logprobs(logits, input_ids)
        assert result.shape == (batch_size, 0)
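
    @pytest.mark.timeout(10)
    def test_matches_reference_sketch(self):
        """Sketch of a cross-check against the assumed reference helper above."""
        # Assumes the contract described by _reference_compute_logprobs: trim the
        # logits to the number of target ids and scale by temperature before the
        # log-softmax. If compute_logprobs uses a different convention, adjust or
        # drop this check; it is not part of the original suite.
        batch_size, seq_len, vocab_size = 3, 5, 11
        logits = torch.randn(batch_size, seq_len, vocab_size)
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1))

        for temperature in (1.0, 0.7, 3.0):
            result = compute_logprobs(logits, input_ids, temperature=temperature)
            expected = _reference_compute_logprobs(logits, input_ids, temperature)
            assert torch.allclose(result, expected, atol=1e-5)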