Skip to content

Commit bc068d8

Browse files
committed
luspo fix
1 parent 0248d72 commit bc068d8

File tree

1 file changed

+98
-36
lines changed

1 file changed

+98
-36
lines changed

src/liger_kernel/chunked_loss/fused_linear_ppo.py

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 2048
19+
_SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD = 4096
1920

2021

2122
def _next_power_of_two(x):
@@ -115,8 +116,8 @@ def _selective_logprob_forward_torch(hidden, weight, targets, bias=None, tempera
115116
end = min(start + vocab_chunk_size, vocab_size)
116117
chunk_width = end - start
117118
weight_chunk = weight[start:end].to(hidden.dtype)
118-
torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width])
119119
logits_chunk = logits_buf[:, :chunk_width]
120+
torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width])
120121
logits_chunk.copy_(mm_buf[:, :chunk_width])
121122
if bias is not None:
122123
logits_chunk.add_(bias[start:end].to(torch.float32))
@@ -138,6 +139,48 @@ def _selective_logprob_forward_torch(hidden, weight, targets, bias=None, tempera
138139
return target_logit - log_z, log_z
139140

140141

142+
def _selective_logprob_forward_autograd(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
143+
n_rows, _ = hidden.shape
144+
vocab_size, _ = weight.shape
145+
inv_t = 1.0 / temperature
146+
147+
if vocab_size <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD:
148+
logits = hidden @ weight.to(hidden.dtype).t()
149+
logits = logits.float()
150+
if bias is not None:
151+
logits = logits + bias.to(torch.float32)
152+
logits = logits * inv_t
153+
return torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
154+
155+
max_old = torch.full((n_rows,), float("-inf"), device=hidden.device, dtype=torch.float32)
156+
sum_exp = torch.zeros((n_rows,), device=hidden.device, dtype=torch.float32)
157+
target_logit = torch.zeros((n_rows,), device=hidden.device, dtype=torch.float32)
158+
row_idx = torch.arange(n_rows, device=hidden.device)
159+
160+
for start in range(0, vocab_size, vocab_chunk_size):
161+
end = min(start + vocab_chunk_size, vocab_size)
162+
logits_chunk = hidden @ weight[start:end].to(hidden.dtype).t()
163+
logits_chunk = logits_chunk.float()
164+
if bias is not None:
165+
logits_chunk = logits_chunk + bias[start:end].to(torch.float32)
166+
logits_chunk = logits_chunk * inv_t
167+
168+
chunk_max = logits_chunk.amax(dim=-1)
169+
max_new = torch.maximum(max_old, chunk_max)
170+
rescale = torch.exp(max_old - max_new)
171+
chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1))
172+
173+
sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1)
174+
max_old = max_new
175+
176+
in_chunk = (targets >= start) & (targets < end)
177+
local_idx = torch.clamp(targets - start, 0, end - start - 1)
178+
target_logit = target_logit + logits_chunk[row_idx, local_idx] * in_chunk
179+
180+
log_z = max_old + torch.log(sum_exp)
181+
return target_logit - log_z
182+
183+
141184
def _selective_logprob_forward_triton(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
142185
n_rows, hidden_size = hidden.shape
143186
block_h = min(128, _next_power_of_two(hidden_size))
@@ -171,18 +214,6 @@ def _selective_logprob_forward_triton(hidden, weight, targets, bias=None, temper
171214

172215

173216
def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
    """Entry point for selected-token log-probability computation.

    Kept as a thin dispatcher so call sites do not need to know which backend
    implementation computes the per-token log-probabilities; currently the
    chunked torch implementation is always used.
    """
    return _selective_logprob_forward_torch(
        hidden,
        weight,
        targets,
        bias=bias,
        temperature=temperature,
        vocab_chunk_size=vocab_chunk_size,
    )
187218

188219

@@ -226,12 +257,12 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
226257
n_rows, _ = hidden.shape
227258
vocab_size = weight.shape[0]
228259
has_bias = bias is not None
260+
hidden_fp32 = hidden.float()
229261

230262
grad_hidden = torch.zeros(hidden.shape, device=hidden.device, dtype=torch.float32)
231263
grad_weight = torch.zeros(weight.shape, device=weight.device, dtype=torch.float32)
232264
grad_bias = torch.zeros((vocab_size,), device=weight.device, dtype=torch.float32) if has_bias else None
233265

234-
mm_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=hidden.dtype)
235266
logits_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=torch.float32)
236267

237268
grad_logprobs = grad_logprobs.to(torch.float32)
@@ -240,11 +271,16 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
240271
for start in range(0, vocab_size, vocab_chunk_size):
241272
end = min(start + vocab_chunk_size, vocab_size)
242273
chunk_width = end - start
243-
weight_chunk = weight[start:end]
244-
245-
torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width])
274+
weight_chunk = weight[start:end].float()
246275
logits_chunk = logits_buf[:, :chunk_width]
247-
logits_chunk.copy_(mm_buf[:, :chunk_width])
276+
allow_tf32 = torch.backends.cuda.matmul.allow_tf32
277+
if hidden.is_cuda:
278+
torch.backends.cuda.matmul.allow_tf32 = False
279+
try:
280+
torch.mm(hidden_fp32, weight_chunk.t(), out=logits_chunk)
281+
finally:
282+
if hidden.is_cuda:
283+
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
248284
if has_bias:
249285
logits_chunk.add_(bias[start:end].to(torch.float32))
250286
logits_chunk.mul_(inv_t)
@@ -257,8 +293,15 @@ def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logpr
257293
grad_logits[row_idx, local_idx] += grad_logprobs * in_chunk
258294
grad_logits.mul_(inv_t)
259295

260-
grad_hidden.add_(grad_logits @ weight_chunk.float())
261-
grad_weight[start:end].add_(grad_logits.t() @ hidden.float())
296+
allow_tf32 = torch.backends.cuda.matmul.allow_tf32
297+
if hidden.is_cuda:
298+
torch.backends.cuda.matmul.allow_tf32 = False
299+
try:
300+
grad_hidden.add_(grad_logits @ weight_chunk)
301+
grad_weight[start:end].add_(grad_logits.t() @ hidden_fp32)
302+
finally:
303+
if hidden.is_cuda:
304+
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
262305
if has_bias:
263306
grad_bias[start:end].add_(grad_logits.sum(dim=0))
264307

@@ -385,7 +428,8 @@ def forward(
385428
delta=delta,
386429
use_bias_correction_kl=use_bias_correction_kl,
387430
)
388-
compiled_compute_loss = torch.compile(compute_loss) if compiled else compute_loss
431+
use_compiled_compute_loss = compiled and weight.shape[0] > _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD
432+
compiled_compute_loss = torch.compile(compute_loss) if use_compiled_compute_loss else compute_loss
389433

390434
def fused_fwd_bwd(
391435
input_chunk,
@@ -463,7 +507,10 @@ def accumulate_chunk(
463507
aggregated_metrics[i].append(metric)
464508

465509
# Process input in chunks based on chunk_size
466-
chunks = max(1, _input.shape[0] // chunk_size)
510+
if weight.shape[0] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD:
511+
chunks = 1
512+
else:
513+
chunks = max(1, _input.shape[0] // chunk_size)
467514
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
468515
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
469516
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
@@ -598,13 +645,24 @@ def _compute_chunk_loss(
598645

599646
if use_ref_model and ref_per_token_logps_chunk is None:
600647
with torch.no_grad():
601-
ref_per_token_logps_chunk = LigerFusedLinearPPOBase.chunk_forward(
602-
ref_input_chunk,
603-
ref_weight,
604-
selected_token_ids_chunk,
605-
bias=ref_bias,
606-
temperature=temperature,
607-
)
648+
if ref_weight.shape[0] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD:
649+
ref_logits = ref_input_chunk @ ref_weight.t()
650+
if ref_bias is not None:
651+
ref_logits = ref_logits + ref_bias.float()
652+
ref_logits = ref_logits.float()
653+
if temperature != 1.0:
654+
ref_logits = ref_logits / temperature
655+
ref_per_token_logps_chunk = torch.log_softmax(ref_logits, dim=-1).gather(
656+
-1, selected_token_ids_chunk.unsqueeze(-1)
657+
).squeeze(-1)
658+
else:
659+
ref_per_token_logps_chunk = LigerFusedLinearPPOBase.chunk_forward(
660+
ref_input_chunk,
661+
ref_weight,
662+
selected_token_ids_chunk,
663+
bias=ref_bias,
664+
temperature=temperature,
665+
)
608666

609667
# Compute chunk loss and metrics using the provided loss function
610668
chunk_loss, chunk_metrics = ppo_loss_fn(
@@ -632,16 +690,20 @@ def _compute_chunk_loss(
632690
@staticmethod
def chunk_forward(input_chunk, weight, selected_token_ids, bias=None, temperature=1.0):
    """Compute selected-token log probabilities without materializing full vocab logits.

    Args:
        input_chunk: ``(batch, seq_len, hidden)`` activations.
        weight: ``(vocab, hidden)`` LM-head weight.
        selected_token_ids: ``(batch, seq_len)`` token ids to gather log-probs for.
        bias: optional ``(vocab,)`` LM-head bias.
        temperature: softmax temperature applied to the logits.

    Returns:
        ``(batch, seq_len)`` per-token log-probabilities (float32).
    """
    if weight.shape[0] <= _SELECTIVE_LOGPROB_EXACT_VOCAB_THRESHOLD:
        # Small vocab: a single full-logits log_softmax is cheap and exact.
        # Cast to fp32 *before* adding the bias so this path matches the
        # chunked helper (_selective_logprob_forward_autograd), which adds
        # the bias in fp32 — adding it in the input dtype loses precision.
        logits = (input_chunk @ weight.t()).float()
        if bias is not None:
            logits = logits + bias.to(torch.float32)
        if temperature != 1.0:
            logits = logits / temperature
        return torch.log_softmax(logits, dim=-1).gather(-1, selected_token_ids.unsqueeze(-1)).squeeze(-1)

    # Large vocab: flatten to (batch*seq, hidden) and stream over vocab chunks.
    batch_size, seq_len, hidden_size = input_chunk.shape
    hidden = input_chunk.reshape(batch_size * seq_len, hidden_size).contiguous()
    targets = selected_token_ids.reshape(batch_size * seq_len).contiguous()
    per_token_logps = _selective_logprob_forward_autograd(
        hidden, weight, targets, bias, temperature, _SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE
    )
    return per_token_logps.reshape(batch_size, seq_len)
647709

0 commit comments

Comments
 (0)