
Commit 0dd6610

Add tests for reverse KL loss

1 parent 1eee287 commit 0dd6610

File tree

2 files changed: +166 -2 lines changed

apps/on_policy_distillation/main.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 from datasets import load_dataset
 from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger

@@ -112,7 +112,7 @@ async def main(cfg: DictConfig):
     mlogger = await get_or_create_metric_logger(process_name="Controller")
     await mlogger.init_backends.call_one(cfg.metric_logging)
     student_trainer, student_generator, teacher = await asyncio.gather(
-        RLTrainer.options(**cfg.services.trainer).as_actor(
+        TitanTrainer.options(**cfg.services.trainer).as_actor(
             **cfg.trainer, loss=reverse_kl_loss
         ),
         Generator.options(**cfg.services.student_generator).as_service(
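
For context, the tests added below exercise reverse_kl_loss(logits, response, teacher_logprobs, padding_mask), the custom loss that the diff above passes into TitanTrainer. The sketch below is an illustration only, not the PR's implementation: it assumes compute_logprobs(logits, response) returns per-token log-probabilities of the response tokens taken from the last response-length logit positions (as the test comments describe), and it assumes a REINFORCE-style surrogate whose properties match what the tests check (zero loss when student equals teacher, gradients that raise under-weighted student logprobs). The names compute_logprobs_sketch and reverse_kl_loss_sketch are hypothetical.

# Illustrative sketch only -- not the torchforge implementation under test.
import torch
import torch.nn.functional as F


def compute_logprobs_sketch(logits: torch.Tensor, response: torch.Tensor) -> torch.Tensor:
    """Assumed behavior of forge.util.ops.compute_logprobs, per the test comments:
    slice logits[:, -T-1:-1] so sliced position t scores response token t, then gather."""
    seq_len = response.size(1)
    logprobs = F.log_softmax(logits[:, -seq_len - 1 : -1], dim=-1)
    return torch.gather(logprobs, 2, response.unsqueeze(-1)).squeeze(-1)


def reverse_kl_loss_sketch(
    logits: torch.Tensor,            # (B, T+1, V) student logits
    response: torch.Tensor,          # (B, T) sampled response token ids
    teacher_logprobs: torch.Tensor,  # (B, T) teacher logprobs of those tokens
    padding_mask: torch.Tensor,      # (B, T) True for real (non-pad) tokens
) -> torch.Tensor:
    student_logprobs = compute_logprobs_sketch(logits, response)
    mask = padding_mask.float()
    # Per-token reverse KL estimate on the sampled tokens: log p_student - log p_teacher.
    reverse_kl = student_logprobs - teacher_logprobs
    # REINFORCE-style surrogate: advantage = -reverse_kl, treated as a constant.
    advantages = (-reverse_kl).detach()
    return -(advantages * student_logprobs * mask).sum() / mask.sum().clamp(min=1.0)

Minimizing a surrogate of this shape concentrates the student's probability mass on the tokens the teacher prefers, which is the mode-seeking behavior quoted in the last test below.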
New test file
Lines changed: 164 additions & 0 deletions

@@ -0,0 +1,164 @@
"""
Tests comparing reverse_kl_loss from the PR with the Tinker / Thinking Machines implementation.
PR: https://github.com/meta-pytorch/torchforge/pull/527

Citations from the Tinker implementation:
- Blog post pseudocode: https://thinkingmachines.ai/blog/on-policy-distillation/
- Tinker Cookbook: https://github.com/thinking-machines-lab/tinker-cookbook
"""

import torch

from apps.on_policy_distillation.main import reverse_kl_loss
from forge.util.ops import compute_logprobs


class TestReverseKLLoss:
    """
    These tests cover three things:
    1. Basic input / output handling of the parameters.
    2. Agreement with the quantities computed in the Tinker implementation.
    3. Expected behavior: the loss pushes student logprobs toward the teacher's.
    """

    def test_vs_tinker_loss(self):
        """Test the complete pattern from Tinker's implementation."""
        batch_size, seq_len, vocab_size = 2, 5, 50

        prompt = torch.randint(0, vocab_size, (batch_size, seq_len))
        response = torch.randint(0, vocab_size, (batch_size, seq_len))

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L71
        input_ids = torch.cat([prompt, response], dim=-1)

        # compute_logprobs slices logits[:, -seq_len-1:-1], so the logit position
        # that scores response token t is prompt_len + t.
        teacher_logits = torch.full(
            (batch_size, input_ids.size(1) + 1, vocab_size), -1000.0
        )
        for b in range(batch_size):
            for t in range(response.size(1)):
                teacher_logits[b, prompt.size(1) + t, response[b, t]] = 0.0

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L77
        teacher_logprobs = compute_logprobs(teacher_logits, response)

        student_logits = torch.full(
            (batch_size, input_ids.size(1) + 1, vocab_size), -1000.0
        )
        for b in range(batch_size):
            for t in range(response.size(1)):
                student_logits[b, prompt.size(1) + t, response[b, t]] = 0.5

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L86
        student_logprobs = compute_logprobs(student_logits, response)

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L87
        mask = response == 0
        mask = mask.float()

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L89
        reverse_kl = (student_logprobs - teacher_logprobs) * mask

        # https://github.com/thinking-machines-lab/tinker-cookbook/blob/6c9f7a4f254c01010509a147e7fd80026654464b/tinker_cookbook/distillation/train_on_policy.py#L100
        advantages = -1.0 * mask * reverse_kl

        # Sanity checks on the Tinker-style quantities.
        assert reverse_kl.shape == (batch_size, seq_len)
        assert advantages.shape == (batch_size, seq_len)
        assert torch.isfinite(reverse_kl).all()
        assert torch.isfinite(advantages).all()

    def test_zero_kl_property(self):
        """Test that KL is zero when distributions match perfectly."""
        batch_size, seq_len, vocab_size = 2, 5, 50

        response = torch.randint(0, vocab_size, (batch_size, seq_len))

        # Create logits for seq_len+1 positions (to predict seq_len response tokens)
        # compute_logprobs will slice logits[:, -seq_len-1:-1] to align with response
        logits = torch.full((batch_size, seq_len + 1, vocab_size), -1000.0)
        for b in range(batch_size):
            for t in range(seq_len):
                logits[b, t, response[b, t]] = 0.0

        # Get student log probabilities for the selected tokens using compute_logprobs
        student_logprobs = compute_logprobs(logits, response)

        # Set teacher to match student exactly
        teacher_logprobs = student_logprobs.clone().detach()

        # No padding
        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

        loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)

        # When student matches teacher, reverse_kl = 0, advantages = 0, loss = 0
        assert abs(loss.item()) < 1e-5, "Loss should be ~0 when student matches teacher"

    def test_loss_direction(self):
        """Test that gradients push student logprobs toward the teacher's."""
        batch_size, seq_len, vocab_size = 1, 1, 10  # noqa

        # Single-token case for clarity
        response = torch.tensor([[5]])  # Token index 5

        # Student has low probability for token 5
        # Need seq_len+1 positions for compute_logprobs alignment
        logits = torch.full((1, 2, vocab_size), 0.0, requires_grad=True)
        logits.data[0, 0, 5] = -3.0  # Low logit for token 5

        # Teacher has higher probability (less negative logprob)
        teacher_logprobs = torch.tensor([[-1.0]])

        padding_mask = torch.ones(1, 1, dtype=torch.bool)

        # Compute loss and gradients
        loss = reverse_kl_loss(logits, response, teacher_logprobs, padding_mask)
        loss.backward()

        # When the student logprob is lower than the teacher's, the gradient should push it higher
        # Gradient at index 5 should be negative (increase logit -> increase logprob)
        assert logits.grad is not None
        assert (
            logits.grad[0, 0, 5].item() < 0
        ), "Gradient should push logit higher when student < teacher"

    def test_mode_seeking_behavior(self):
        """
        Test that reverse KL exhibits mode-seeking behavior.

        Citation from the blog post:
        "reverse KL is 'mode seeking' — it learns one specific behavior
        (the teacher's) instead of spreading its distribution across
        several suboptimal options."
        (https://thinkingmachines.ai/blog/on-policy-distillation/)
        """
        batch_size, seq_len, vocab_size = 1, 3, 10

        response = torch.tensor([[2, 5, 7]])

        # Teacher has high confidence (low entropy)
        teacher_logprobs = torch.tensor([[-0.1, -0.1, -0.1]])

        # Student 1: spread distribution (high entropy)
        # Need seq_len+1 positions for compute_logprobs alignment
        logits_spread = torch.zeros(batch_size, seq_len + 1, vocab_size)

        # Student 2: focused distribution (low entropy, matching the teacher's confidence)
        logits_focused = torch.full((batch_size, seq_len + 1, vocab_size), -10.0)
        logits_focused[0, 0, 2] = 10.0
        logits_focused[0, 1, 5] = 10.0
        logits_focused[0, 2, 7] = 10.0

        padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

        # Compute losses
        loss_spread = reverse_kl_loss(
            logits_spread, response, teacher_logprobs, padding_mask
        )
        loss_focused = reverse_kl_loss(
            logits_focused, response, teacher_logprobs, padding_mask
        )

        # These assertions only confirm that both students yield finite scalar losses;
        # the mode-seeking pressure itself comes from minimizing the reverse KL.
        assert isinstance(loss_spread.item(), float)
        assert isinstance(loss_focused.item(), float)

        # Both losses should be finite
        assert torch.isfinite(loss_spread)
        assert torch.isfinite(loss_focused)
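
As a standalone illustration of the gradient-direction reasoning in test_loss_direction, the snippet below reproduces the single-token setup against the same assumed REINFORCE-style surrogate (an assumption for illustration; the PR's reverse_kl_loss may differ in its exact reduction):

# Hypothetical worked example; mirrors test_loss_direction under the assumed surrogate.
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.zeros(1, 2, vocab_size, requires_grad=True)
with torch.no_grad():
    logits[0, 0, 5] = -3.0                    # student under-weights token 5
teacher_logprob = torch.tensor(-1.0)          # teacher assigns it higher probability

# Student logprob of token 5 at the position that scores it (position 0 here).
student_logprob = F.log_softmax(logits[:, 0], dim=-1)[0, 5]

# Advantage = -(student - teacher), held constant; surrogate loss = -advantage * logprob.
advantage = (teacher_logprob - student_logprob).detach()
loss = -advantage * student_logprob
loss.backward()

print(logits.grad[0, 0, 5])  # negative, so a gradient step raises the logit for token 5

The sign matches the assertion in test_loss_direction: when the student's logprob sits below the teacher's, the gradient pushes the corresponding logit up.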
