
Commit af6709b

[fix] add Gemma3 support (#1648)
* [fix] add Gemma3 support: final_logit_softcapping and meta device buffer handling
* [fix] handle softcap=0 as disabled
* [fix] use comment not docstring
1 parent 137e36a commit af6709b
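Note on the first two commit bullets: Gemma-family configs can set final_logit_softcapping, which squashes the lm_head logits into (-softcap, softcap) with a scaled tanh before the softmax, and a value of 0 (or a missing field) means the capping is disabled. A minimal illustrative sketch, not part of this diff (the helper name softcap_logits is made up):

import torch

def softcap_logits(logits: torch.Tensor, softcap: float | None) -> torch.Tensor:
    # Gemma-style final logit softcapping: output is bounded in (-softcap, softcap)
    # and is approximately the identity for |logits| << softcap.
    if not softcap:  # None or 0.0 -> treated as disabled, matching the commit message
        return logits
    return softcap * torch.tanh(logits / softcap)

logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap_logits(logits, 30.0))  # smoothly clamped towards +/-30
print(softcap_logits(logits, 0.0))   # returned unchanged: softcap=0 means disabled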

File tree

src/prime_rl/trainer/model.py
src/prime_rl/trainer/models/layers/lm_head.py
src/prime_rl/trainer/models/layers/lm_head_gemma.py

3 files changed, +233 -0 lines changed

src/prime_rl/trainer/model.py

Lines changed: 16 additions & 0 deletions
@@ -379,6 +379,11 @@ def can_reinit_empty_buffers(model: nn.Module):
     if len(buffer_names) == 1 and buffer_names[0] == "model.rotary_emb.inv_freq":
         return True
 
+    # Gemma3 model (has embed_scale and local rotary emb)
+    gemma3_buffers = {"model.embed_tokens.embed_scale", "model.rotary_emb.inv_freq", "model.rotary_emb_local.inv_freq"}
+    if set(buffer_names) == gemma3_buffers:
+        return True
+
     get_logger().warning(f"Model cannot be loaded using meta device because of buffers: {buffer_names}")
     return False
 
@@ -390,6 +395,17 @@ def fix_model_post_empty(model: nn.Module):
         rotary_emb = model.model.rotary_emb
         inv_freq, rotary_emb.attention_scaling = rotary_emb.rope_init_fn(rotary_emb.config, rotary_emb.inv_freq.device)
         rotary_emb.inv_freq.copy_(inv_freq)
+    # Gemma3 local rotary emb
+    if "model.rotary_emb_local.inv_freq" in buffer_names:
+        rotary_emb_local = model.model.rotary_emb_local
+        inv_freq_local, rotary_emb_local.attention_scaling = rotary_emb_local.rope_init_fn(
+            rotary_emb_local.config, rotary_emb_local.inv_freq.device
+        )
+        rotary_emb_local.inv_freq.copy_(inv_freq_local)
+    # Gemma3 embed_scale (scalar computed from hidden_size)
+    if "model.embed_tokens.embed_scale" in buffer_names:
+        embed_scale = model.config.hidden_size**0.5
+        model.model.embed_tokens.embed_scale.fill_(embed_scale)
 
 
 def reshard_module(model: nn.Module):
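Background for the hunks above: parameters and buffers created under torch.device("meta") hold no data, and to_empty() materializes them as uninitialized storage, so non-persistent buffers such as Gemma3's embed_tokens.embed_scale and the rotary inv_freq tables are never read from the checkpoint and must be recomputed, which is what fix_model_post_empty does. A self-contained sketch of that situation (TinyGemmaEmbed is an illustrative stand-in, not a repo class):

import torch
import torch.nn as nn

class TinyGemmaEmbed(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(32, hidden_size)
        # Gemma3 keeps embed_scale as a (non-persistent) buffer; its value is sqrt(hidden_size).
        self.register_buffer("embed_scale", torch.tensor(hidden_size**0.5))

with torch.device("meta"):
    m = TinyGemmaEmbed()          # no real storage allocated yet

m = m.to_empty(device="cpu")      # buffers now contain uninitialized memory
m.embed_scale.fill_(16**0.5)      # recompute it, as fix_model_post_empty does for Gemma3
# (real code would also load the checkpoint weights into embed_tokens at this point)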

src/prime_rl/trainer/models/layers/lm_head.py

Lines changed: 13 additions & 0 deletions
@@ -222,6 +222,15 @@ def inject_prime_lm_head(model: nn.Module, chunk_size: int | None = None) -> None:
     )
 
     logger = get_logger()
+
+    # Check for Gemma-style softcapping - dispatch to specialized implementation
+    final_logit_softcapping = getattr(model.config, "final_logit_softcapping", None)
+    if final_logit_softcapping:
+        from prime_rl.trainer.models.layers.lm_head_gemma import inject_gemma_lm_head
+
+        inject_gemma_lm_head(model, chunk_size, final_logit_softcapping)
+        return
+
     logger.info(f"Injecting Prime LM head with chunk size {chunk_size}")
 
     # Replace the lm_head with the appropriate wrapper
@@ -235,6 +244,10 @@ def inject_prime_lm_head(model: nn.Module, chunk_size: int | None = None) -> None:
     model.lm_head.weight = old_lm_head.weight
     del old_lm_head
 
+    _patch_model_forward(model)
+
+
+def _patch_model_forward(model: nn.Module) -> None:
     # Patch the forward method to use the new lm_head with labels and temperature
     def new_forward(
         self: nn.Module,
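The dispatch above relies on truthiness: getattr returns None when the config has no final_logit_softcapping field, and a checkpoint that sets it to 0 also falls through to the standard Prime LM head, which is the "handle softcap=0 as disabled" bullet from the commit message. A small illustration with a stand-in config object (uses_gemma_head is made up for this note):

from types import SimpleNamespace

def uses_gemma_head(config) -> bool:
    # Mirrors the check in inject_prime_lm_head: only a nonzero softcap dispatches.
    return bool(getattr(config, "final_logit_softcapping", None))

print(uses_gemma_head(SimpleNamespace()))                             # False: field absent
print(uses_gemma_head(SimpleNamespace(final_logit_softcapping=0.0)))  # False: 0 means disabled
print(uses_gemma_head(SimpleNamespace(final_logit_softcapping=30.0))) # True: nonzero softcap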
src/prime_rl/trainer/models/layers/lm_head_gemma.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from prime_rl.trainer.models.layers.lm_head import (
+    PrimeLmOutput,
+    _online_logsumexp_and_weighted_update,
+    _patch_model_forward,
+)
+from prime_rl.utils.logger import get_logger
+
+
+class GemmaFusedOutputLinear(torch.nn.Linear):
+    def __init__(self, in_features: int, out_features: int, chunk_size: int, softcap: float):
+        super().__init__(in_features, out_features, bias=False)
+        self.chunk_size = chunk_size
+        self.softcap = softcap
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        labels: torch.Tensor | None = None,
+        temperature: Tensor | None = None,
+    ) -> PrimeLmOutput:
+        assert labels is not None, "GemmaFusedOutputLinear requires labels for chunked logprob computation"
+        assert temperature is not None, "GemmaFusedOutputLinear requires per-token temperatures"
+
+        b, s, h = hidden_states.shape
+        hidden_states = hidden_states.reshape(b * s, h).contiguous()
+        labels = labels.reshape(b * s).contiguous()
+        inv_t = 1.0 / temperature.reshape(b * s).contiguous()  # [N]
+
+        logprobs, entropy = _GemmaChunkedLogProbEntropyFn.apply(
+            hidden_states, self.weight, labels, inv_t, self.chunk_size, self.softcap
+        )
+
+        logprobs = logprobs.reshape(b, s)
+        entropy = entropy.reshape(b, s)
+        return PrimeLmOutput(logprobs=logprobs, entropy=entropy)
+
+
+class GemmaVanillaOutputLinear(torch.nn.Linear):
+    def __init__(self, in_features: int, out_features: int, softcap: float):
+        super().__init__(in_features, out_features, bias=False)
+        self.softcap = softcap
+
+    def forward(
+        self, hidden_states: torch.Tensor, labels: torch.Tensor | None = None, temperature: Tensor | None = None
+    ) -> PrimeLmOutput:
+        logits = super().forward(hidden_states)
+        logits = self.softcap * torch.tanh(logits / self.softcap)
+        return PrimeLmOutput(logits=logits)
+
+
+class _GemmaChunkedLogProbEntropyFn(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,
+        hidden: torch.Tensor,  # [N, H]
+        weight: torch.Tensor,  # [V, H]
+        labels: torch.Tensor,  # [N]
+        inv_temperature: torch.Tensor,  # [N]
+        chunk_size: int,
+        softcap: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Returns (per-token logprobs, per-token entropy) without materializing [N, V].
+
+        Important: entropy is computed from the *same* per-chunk logits used for the softmax
+        normalization (no extra W @ hidden matmul).
+        """
+        assert hidden.dim() == 2, f"expected hidden [N,H], got {tuple(hidden.shape)}"
+        assert weight.dim() == 2, f"expected weight [V,H], got {tuple(weight.shape)}"
+        assert labels.dim() == 1, f"expected labels [N], got {tuple(labels.shape)}"
+        assert inv_temperature.dim() == 1, f"expected inv_temperature [N], got {tuple(inv_temperature.shape)}"
+        assert hidden.shape[0] == labels.shape[0], "hidden/labels N mismatch"
+        assert hidden.shape[1] == weight.shape[1], "hidden/weight H mismatch"
+        assert hidden.shape[0] == inv_temperature.shape[0], "hidden/inv_temperature N mismatch"
+        assert chunk_size > 0
+
+        device = hidden.device
+        n = hidden.shape[0]
+        vocab = weight.shape[0]
+
+        # Running stats in fp32.
+        m = torch.full((n,), float("-inf"), device=device, dtype=torch.float32)
+        s = torch.zeros((n,), device=device, dtype=torch.float32)
+        t = torch.zeros((n,), device=device, dtype=torch.float32)
+        target_logits = torch.zeros((n,), device=device, dtype=torch.float32)
+
+        inv_t_broadcast = inv_temperature.unsqueeze(-1)  # [N, 1]
+
+        for start in range(0, vocab, chunk_size):
+            end = min(start + chunk_size, vocab)
+            w_chunk = weight[start:end]  # [C, H]
+            logits = hidden @ w_chunk.t()  # [N, C] (model dtype)
+            logits_f = logits.to(torch.float32)  # [N, C] fp32
+
+            # Apply final logit softcapping (Gemma2/3) before temperature
+            logits_f = softcap * torch.tanh(logits_f / softcap)
+            logits_f = logits_f * inv_t_broadcast  # [N, C] fp32
+
+            # Shared intermediates for logZ and entropy stats.
+            m, s, t = _online_logsumexp_and_weighted_update(m, s, t, logits_f)
+
+            # Fill target logits for labels that fall in this chunk.
+            mask = (labels >= start) & (labels < end)
+            if torch.any(mask):
+                idx = (labels[mask] - start).to(torch.long)
+                target_logits[mask] = logits_f[mask, idx]
+
+        logz = m + torch.log(s)
+        logprobs = target_logits - logz
+        entropy = logz - (t / s)
+
+        # Save for backward (recompute logits per chunk for grad)
+        ctx.save_for_backward(hidden, weight, labels, logz)
+        ctx.inv_temperature = inv_temperature
+        ctx.chunk_size = chunk_size
+        ctx.softcap = softcap
+
+        # Return fp32 for numerical stability (matching baseline behavior).
+        return logprobs, entropy
+
+    @staticmethod
+    def backward(ctx, grad_logprobs: torch.Tensor, grad_entropy: torch.Tensor | None):
+        assert grad_entropy is None or torch.all(grad_entropy == 0.0), (
+            "Backward through entropy is not implemented in GemmaFusedOutputLinear"
+        )
+
+        hidden, weight, labels, logz = ctx.saved_tensors
+        inv_temperature: torch.Tensor = ctx.inv_temperature  # [N]
+        chunk_size: int = ctx.chunk_size
+        softcap: float = ctx.softcap
+
+        n, h = hidden.shape
+        vocab = weight.shape[0]
+
+        grad_hidden = torch.zeros_like(hidden)
+        grad_weight = torch.zeros_like(weight)
+
+        g = grad_logprobs.to(torch.float32)  # [N] fp32 for stable scaling
+
+        inv_t_broadcast = inv_temperature.unsqueeze(-1)  # [N, 1]
+
+        for start in range(0, vocab, chunk_size):
+            end = min(start + chunk_size, vocab)
+            w_chunk = weight[start:end]  # [C, H]
+
+            logits = hidden @ w_chunk.t()  # [N, C] (model dtype)
+            logits_f = logits.to(torch.float32)  # [N, C] fp32
+
+            # Apply final logit softcapping (Gemma2/3) before temperature
+            tanh_val = torch.tanh(logits_f / softcap)
+            logits_f = softcap * tanh_val
+            logits_f = logits_f * inv_t_broadcast  # [N, C] fp32
+
+            # p = softmax(logits_f) chunk = exp(logits_f - logz)
+            p = torch.exp(logits_f - logz.unsqueeze(-1))  # [N, C] fp32
+
+            # dL/dlogits = g * (1_{label} - p)
+            grad_logits = (-g).unsqueeze(-1) * p  # [N, C] fp32
+            mask = (labels >= start) & (labels < end)
+            if torch.any(mask):
+                idx = (labels[mask] - start).to(torch.long)
+                grad_logits[mask, idx] += g[mask]
+
+            # Chain through temperature scaling
+            grad_logits = grad_logits * inv_t_broadcast
+
+            # Chain through softcapping: d/dx[c*tanh(x/c)] = 1 - tanh^2(x/c)
+            grad_logits = grad_logits * (1 - tanh_val**2)
+
+            grad_hidden.add_(grad_logits.to(hidden.dtype) @ w_chunk)
+            grad_w_chunk = grad_logits.to(weight.dtype).t() @ hidden  # [C, H]
+            grad_weight[start:end].add_(grad_w_chunk)
+
+        return grad_hidden, grad_weight, None, None, None, None
+
+
+def inject_gemma_lm_head(model: nn.Module, chunk_size: int | None, softcap: float) -> None:
+    logger = get_logger()
+    logger.info(f"Injecting Gemma LM head with chunk size {chunk_size}, softcap={softcap}")
+
+    old_lm_head = model.lm_head
+    if chunk_size is not None:
+        model.lm_head = GemmaFusedOutputLinear(
+            in_features=old_lm_head.in_features,
+            out_features=old_lm_head.out_features,
+            chunk_size=chunk_size,
+            softcap=softcap,
+        )
+    else:
+        model.lm_head = GemmaVanillaOutputLinear(
+            in_features=old_lm_head.in_features,
+            out_features=old_lm_head.out_features,
+            softcap=softcap,
+        )
+    model.lm_head.weight = old_lm_head.weight
+    del old_lm_head
+
+    _patch_model_forward(model)
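As a sanity check on the chunked implementation, the per-token log-probabilities and entropies it returns should agree with a dense reference that materializes the full [N, V] logits with the same order of operations (softcap, then per-token temperature, then softmax). A rough, self-contained sketch of that reference on random data; the commented line assumes this new module is importable as prime_rl.trainer.models.layers.lm_head_gemma:

import torch
import torch.nn.functional as F

N, H, V, softcap, temperature = 8, 32, 100, 30.0, 0.7
hidden = torch.randn(N, H)
weight = torch.randn(V, H)
labels = torch.randint(0, V, (N,))
inv_t = torch.full((N,), 1.0 / temperature)

# Dense reference: softcap, then per-token temperature, then softmax over the full vocab.
logits = softcap * torch.tanh((hidden @ weight.t()) / softcap) * inv_t.unsqueeze(-1)
ref_logprobs = F.log_softmax(logits, dim=-1)[torch.arange(N), labels]
probs = logits.softmax(dim=-1)
ref_entropy = torch.logsumexp(logits, dim=-1) - (probs * logits).sum(dim=-1)  # logZ - E_p[logits]

# Expected to match (up to fp32 tolerance):
# _GemmaChunkedLogProbEntropyFn.apply(hidden, weight, labels, inv_t, 16, softcap)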
