from __future__ import annotations

import torch
import torch.nn as nn
from torch import Tensor

from prime_rl.trainer.models.layers.lm_head import (
    PrimeLmOutput,
    _online_logsumexp_and_weighted_update,
    _patch_model_forward,
)
from prime_rl.utils.logger import get_logger


class GemmaFusedOutputLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, chunk_size: int, softcap: float):
        super().__init__(in_features, out_features, bias=False)
        self.chunk_size = chunk_size
        self.softcap = softcap

    def forward(
        self,
        hidden_states: torch.Tensor,
        labels: torch.Tensor | None = None,
        temperature: Tensor | None = None,
    ) -> PrimeLmOutput:
        assert labels is not None, "GemmaFusedOutputLinear requires labels for chunked logprob computation"
        assert temperature is not None, "GemmaFusedOutputLinear requires per-token temperatures"

        b, s, h = hidden_states.shape
        hidden_states = hidden_states.reshape(b * s, h).contiguous()
        labels = labels.reshape(b * s).contiguous()
        inv_t = 1.0 / temperature.reshape(b * s).contiguous()  # [N]

        logprobs, entropy = _GemmaChunkedLogProbEntropyFn.apply(
            hidden_states, self.weight, labels, inv_t, self.chunk_size, self.softcap
        )

        logprobs = logprobs.reshape(b, s)
        entropy = entropy.reshape(b, s)
        return PrimeLmOutput(logprobs=logprobs, entropy=entropy)


class GemmaVanillaOutputLinear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, softcap: float):
        super().__init__(in_features, out_features, bias=False)
        self.softcap = softcap

    def forward(
        self, hidden_states: torch.Tensor, labels: torch.Tensor | None = None, temperature: Tensor | None = None
    ) -> PrimeLmOutput:
        logits = super().forward(hidden_states)
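        # Gemma-style final logit softcapping: smoothly bounds logits to
        # (-softcap, softcap). labels/temperature are accepted for signature
        # compatibility but unused, since this path returns full logits.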
        logits = self.softcap * torch.tanh(logits / self.softcap)
        return PrimeLmOutput(logits=logits)


class _GemmaChunkedLogProbEntropyFn(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        hidden: torch.Tensor,  # [N, H]
        weight: torch.Tensor,  # [V, H]
        labels: torch.Tensor,  # [N]
        inv_temperature: torch.Tensor,  # [N]
        chunk_size: int,
        softcap: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (per-token logprobs, per-token entropy) without materializing [N, V].

        Important: entropy is computed from the *same* per-chunk logits used for the softmax
        normalization (no extra W @ hidden matmul).
        """
        assert hidden.dim() == 2, f"expected hidden [N,H], got {tuple(hidden.shape)}"
        assert weight.dim() == 2, f"expected weight [V,H], got {tuple(weight.shape)}"
        assert labels.dim() == 1, f"expected labels [N], got {tuple(labels.shape)}"
        assert inv_temperature.dim() == 1, f"expected inv_temperature [N], got {tuple(inv_temperature.shape)}"
        assert hidden.shape[0] == labels.shape[0], "hidden/labels N mismatch"
        assert hidden.shape[1] == weight.shape[1], "hidden/weight H mismatch"
        assert hidden.shape[0] == inv_temperature.shape[0], "hidden/inv_temperature N mismatch"
        assert chunk_size > 0

        device = hidden.device
        n = hidden.shape[0]
        vocab = weight.shape[0]

        # Running stats in fp32.
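        #   m: running max of the scaled logits (log-sum-exp shift)
        #   s: running sum of exp(logit - m), so logZ = m + log(s)
        #   t: running sum of exp(logit - m) * logit, so E_p[logit] = t / s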
        m = torch.full((n,), float("-inf"), device=device, dtype=torch.float32)
        s = torch.zeros((n,), device=device, dtype=torch.float32)
        t = torch.zeros((n,), device=device, dtype=torch.float32)
        target_logits = torch.zeros((n,), device=device, dtype=torch.float32)

        inv_t_broadcast = inv_temperature.unsqueeze(-1)  # [N, 1]

        for start in range(0, vocab, chunk_size):
            end = min(start + chunk_size, vocab)
            w_chunk = weight[start:end]  # [C, H]
            logits = hidden @ w_chunk.t()  # [N, C] (model dtype)
            logits_f = logits.to(torch.float32)  # [N, C] fp32

            # Apply final logit softcapping (Gemma2/3) before temperature
            logits_f = softcap * torch.tanh(logits_f / softcap)
            logits_f = logits_f * inv_t_broadcast  # [N, C] fp32

            # Shared intermediates for logZ and entropy stats.
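            # (standard streaming log-sum-exp update: the new running max m' is
            # max(m, chunk max); s and t are rescaled by exp(m - m') before the
            # chunk's exp(logit - m') terms are accumulated)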
            m, s, t = _online_logsumexp_and_weighted_update(m, s, t, logits_f)

            # Fill target logits for labels that fall in this chunk.
            mask = (labels >= start) & (labels < end)
            if torch.any(mask):
                idx = (labels[mask] - start).to(torch.long)
                target_logits[mask] = logits_f[mask, idx]

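        # Final reduction: logZ = m + log(s); the entropy of softmax(l) is
        #   H = -sum_j p_j * (l_j - logZ) = logZ - E_p[l] = logZ - t / s.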
        logz = m + torch.log(s)
        logprobs = target_logits - logz
        entropy = logz - (t / s)

        # Save for backward (logits are recomputed per chunk for the grads);
        # tensors go through save_for_backward rather than raw ctx attributes.
        ctx.save_for_backward(hidden, weight, labels, inv_temperature, logz)
        ctx.chunk_size = chunk_size
        ctx.softcap = softcap

        # Return fp32 for numerical stability (matching baseline behavior).
        return logprobs, entropy

    @staticmethod
    def backward(ctx, grad_logprobs: torch.Tensor, grad_entropy: torch.Tensor | None):
        assert grad_entropy is None or torch.all(grad_entropy == 0.0), (
            "Backward through entropy is not implemented in GemmaFusedOutputLinear"
        )

        hidden, weight, labels, inv_temperature, logz = ctx.saved_tensors
        chunk_size: int = ctx.chunk_size
        softcap: float = ctx.softcap

        n, h = hidden.shape
        vocab = weight.shape[0]

        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)

        g = grad_logprobs.to(torch.float32)  # [N] fp32 for stable scaling

        inv_t_broadcast = inv_temperature.unsqueeze(-1)  # [N, 1]

        for start in range(0, vocab, chunk_size):
            end = min(start + chunk_size, vocab)
            w_chunk = weight[start:end]  # [C, H]

            logits = hidden @ w_chunk.t()  # [N, C] (model dtype)
            logits_f = logits.to(torch.float32)  # [N, C] fp32

            # Apply final logit softcapping (Gemma2/3) before temperature
            tanh_val = torch.tanh(logits_f / softcap)
            logits_f = softcap * tanh_val
            logits_f = logits_f * inv_t_broadcast  # [N, C] fp32

            # p = softmax over the full vocab, evaluated chunk-wise: exp(logits_f - logZ)
            p = torch.exp(logits_f - logz.unsqueeze(-1))  # [N, C] fp32

            # dL/dlogits = g * (1_{label} - p)
            grad_logits = (-g).unsqueeze(-1) * p  # [N, C] fp32
            mask = (labels >= start) & (labels < end)
            if torch.any(mask):
                idx = (labels[mask] - start).to(torch.long)
                grad_logits[mask, idx] += g[mask]

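            # Chain rule through scaled = inv_t * softcap * tanh(raw / softcap):
            #   d(scaled)/d(raw) = inv_t * (1 - tanh^2(raw / softcap)),
            # applied in the two steps below.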
            # Chain through temperature scaling
            grad_logits = grad_logits * inv_t_broadcast

            # Chain through softcapping: d/dx[c*tanh(x/c)] = 1 - tanh^2(x/c)
            grad_logits = grad_logits * (1 - tanh_val**2)

            grad_hidden.add_(grad_logits.to(hidden.dtype) @ w_chunk)
            grad_w_chunk = grad_logits.to(weight.dtype).t() @ hidden  # [C, H]
            grad_weight[start:end].add_(grad_w_chunk)

        return grad_hidden, grad_weight, None, None, None, None


def inject_gemma_lm_head(model: nn.Module, chunk_size: int | None, softcap: float) -> None:
    logger = get_logger()
    logger.info(f"Injecting Gemma LM head with chunk size {chunk_size}, softcap={softcap}")

    old_lm_head = model.lm_head
    if chunk_size is not None:
        model.lm_head = GemmaFusedOutputLinear(
            in_features=old_lm_head.in_features,
            out_features=old_lm_head.out_features,
            chunk_size=chunk_size,
            softcap=softcap,
        )
    else:
        model.lm_head = GemmaVanillaOutputLinear(
            in_features=old_lm_head.in_features,
            out_features=old_lm_head.out_features,
            softcap=softcap,
        )
    model.lm_head.weight = old_lm_head.weight
    del old_lm_head

    _patch_model_forward(model)
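

# Minimal sanity-check sketch (illustration only, not part of the trainer):
# compares the chunked fused path against a naive full-vocab log-softmax
# reference. The sizes, chunk size, and softcap value below are arbitrary
# assumptions chosen for the example, not trainer defaults.
if __name__ == "__main__":
    torch.manual_seed(0)
    b, s, h, v = 2, 4, 16, 64
    softcap = 30.0
    head = GemmaFusedOutputLinear(h, v, chunk_size=16, softcap=softcap)
    hidden = torch.randn(b, s, h)
    labels = torch.randint(0, v, (b, s))
    temperature = torch.ones(b, s)
    out = head(hidden, labels=labels, temperature=temperature)

    # Naive reference: materialize all logits, softcap, then log-softmax.
    ref_logits = softcap * torch.tanh((hidden @ head.weight.t()) / softcap)
    ref_logprobs = torch.log_softmax(ref_logits.float(), dim=-1)
    ref_target = ref_logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    ref_entropy = -(ref_logprobs.exp() * ref_logprobs).sum(-1)
    assert torch.allclose(out.logprobs, ref_target, atol=1e-5)
    assert torch.allclose(out.entropy, ref_entropy, atol=1e-4)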