Skip to content

Commit e38cdf3

Browse files
committed
Rename test_compute_logprobs.py to test_ops.py and add align parameter tests
- Renamed the test file to better reflect that it tests ops.py functions.
- Added three new tests for the align parameter in compute_logprobs:
  - test_align_parameter_false: validates align=False (pre-aligned logits)
  - test_align_parameter_true: validates align=True (slicing behavior)
  - test_align_comparison: verifies the slicing logic produces correct results
1 parent 508e238 commit e38cdf3

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/unit_tests/util/test_compute_logprobs.py renamed to tests/unit_tests/util/test_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,56 @@ def test_compute_logprobs_empty_response(self):
109109

110110
result = compute_logprobs(logits, input_ids)
111111
assert result.shape == (batch_size, 0)
112+
113+
@pytest.mark.timeout(10)
def test_align_parameter_false(self):
    """align=False: logits are treated as already aligned with input_ids.

    With pre-aligned inputs, logits[:, i] is the distribution for
    input_ids[:, i], so compute_logprobs should apply no slicing and the
    result should match a direct log-softmax gather.
    """
    n_batch, n_tokens, n_vocab = 2, 3, 5
    scores = torch.randn(n_batch, n_tokens, n_vocab)
    tokens = torch.randint(0, n_vocab, (n_batch, n_tokens))

    actual = compute_logprobs(scores, tokens, align=False)

    # Reference value computed directly, with no slicing applied.
    reference = _textbook_log_softmax(scores, tokens)

    assert torch.allclose(actual, reference, atol=1e-5)
    assert actual.shape == tokens.shape
129+
130+
@pytest.mark.timeout(10)
def test_align_parameter_true(self):
    """align=True (the default): logits must be sliced to match input_ids.

    The logits cover a longer sequence than the target tokens; the
    function is expected to keep only the positions that predict each
    target token, i.e. logits[:, -target_len - 1 : -1].
    """
    n_batch, full_len, n_vocab = 2, 6, 5
    scores = torch.randn(n_batch, full_len, n_vocab)

    # Request log-probs for only the trailing 3 tokens.
    n_targets = 3
    tokens = torch.randint(0, n_vocab, (n_batch, n_targets))

    actual = compute_logprobs(scores, tokens, align=True)

    # Replicate the slice align=True performs, then gather manually.
    trimmed = scores[:, -n_targets - 1 : -1, :]
    reference = _textbook_log_softmax(trimmed, tokens)

    assert torch.allclose(actual, reference, atol=1e-5)
    assert actual.shape == tokens.shape
149+
150+
@pytest.mark.timeout(10)
def test_align_comparison(self):
    """align=True must agree with manual slicing followed by align=False.

    Slicing the logits by hand the same way align=True does, and then
    calling compute_logprobs with align=False, has to yield the same
    result as letting the function do the alignment itself.
    """
    n_batch, full_len, n_vocab = 1, 4, 10
    scores = torch.randn(n_batch, full_len, n_vocab)
    tokens = torch.randint(0, n_vocab, (n_batch, 2))

    via_align = compute_logprobs(scores, tokens, align=True)

    # Reproduce the internal slice, then skip alignment entirely.
    trimmed = scores[:, -tokens.size(1) - 1 : -1, :]
    via_manual = compute_logprobs(trimmed, tokens, align=False)

    # Both paths should give the same result.
    assert torch.allclose(via_align, via_manual, atol=1e-5)

0 commit comments

Comments
 (0)