Skip to content

Commit c967ee1

Browse files
committed
refactor GRPO
1 parent 812b050 commit c967ee1

File tree

3 files changed

+266
-216
lines changed

3 files changed

+266
-216
lines changed

src/liger_kernel/chunked_loss/fused_linear_rlhf.py

Lines changed: 87 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import partial
33

44
import torch
5+
import torch._dynamo.config
56
import torch.nn.functional as F
67

78

@@ -20,15 +21,18 @@ def forward(
2021
_input,
2122
weight,
2223
attention_mask,
23-
rewards,
24+
advantages,
2425
bias=None,
25-
num_generations=4,
26-
beta=0.1,
27-
compiled=True,
28-
use_ref_model=False,
2926
ref_input=None,
3027
ref_weight=None,
3128
ref_bias=None,
29+
old_per_token_logps=None,
30+
epsilon_low=0.2,
31+
epsilon_high=0.2,
32+
beta=0.1,
33+
temperature=1.0,
34+
compiled=True,
35+
use_ref_model=False,
3236
chunk_size=1,
3337
):
3438
"""Chunked forward pass for RLHF loss computation.
@@ -39,21 +43,20 @@ def forward(
3943
_input: Input tensor
4044
weight: Weight tensor
4145
attention_mask: Attention mask tensor
42-
rewards: Rewards tensor
46+
advantages: Advantages tensor
4347
bias: Bias tensor
44-
num_generations: Number of generations per prompt
45-
beta: Weight for the KL penalty
46-
compiled: Whether to use torch compile
47-
use_ref_model: Whether to use a reference model
4848
ref_input: Reference model input tensor
4949
ref_weight: Reference model weight tensor
5050
ref_bias: Reference model bias tensor
51+
old_per_token_logps: Old per token log probabilities tensor
52+
epsilon_low: Lower bound for clipping the importance sampling ratio
53+
epsilon_high: Upper bound for clipping the importance sampling ratio
54+
beta: Weight for the KL penalty
55+
temperature: Temperature for the logits
56+
compiled: Whether to use torch compile
57+
use_ref_model: Whether to use a reference model
5158
chunk_size: Size of chunks for processing in other loss modules
5259
"""
53-
# Save for backward
54-
ctx.beta = beta
55-
ctx.rewards = rewards
56-
5760
# Initialize accumulators
5861
loss_acc = torch.zeros((), device=_input.device)
5962
grad_weight = torch.zeros_like(weight) # [V, H]
@@ -64,43 +67,36 @@ def forward(
6467
# Create a partial function with fixed arguments
6568
compute_loss = partial(
6669
LigerFusedLinearRLHFBase._compute_chunk_loss,
67-
beta=beta,
68-
use_ref_model=use_ref_model,
6970
ref_weight=ref_weight,
7071
ref_bias=ref_bias,
72+
full_attention_mask=attention_mask,
73+
epsilon_low=epsilon_low,
74+
epsilon_high=epsilon_high,
75+
beta=beta,
76+
temperature=temperature,
77+
use_ref_model=use_ref_model,
7178
rlhf_loss_fn=cls.rlhf_loss_fn,
7279
)
7380

74-
def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
81+
def fused_fwd_bwd(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk):
7582
"""Fused forward and backward for a chunk."""
83+
argnums = (0, 1, 4) if bias is not None else (0, 1)
84+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
85+
input_chunk, # arg 0
86+
weight, # arg 1
87+
attention_mask_chunk, # arg 2
88+
advantages_chunk, # arg 3
89+
bias, # arg 4
90+
ref_input_chunk=ref_input_chunk, # arg 5
91+
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 6
92+
)
93+
94+
def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk=None, old_per_token_logps_chunk=None):
95+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
96+
input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk
97+
)
7698
if bias is not None:
77-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
78-
input_chunk, # arg 0
79-
weight, # arg 1
80-
attention_mask_chunk, # arg 2
81-
rewards_chunk, # arg 3
82-
ref_input_chunk, # arg 4
83-
bias, # arg 5
84-
)
85-
else:
86-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
87-
input_chunk, # arg 0
88-
weight, # arg 1
89-
attention_mask_chunk, # arg 2
90-
rewards_chunk, # arg 3
91-
ref_input_chunk, # arg 4
92-
)
93-
94-
def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
95-
if bias is not None:
96-
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
97-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
98-
)
99-
grad_bias.add_(chunk_grad_bias)
100-
else:
101-
(chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
102-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
103-
)
99+
grad_bias.add_(chunk_grad_bias[0])
104100

105101
# Accumulate gradients and loss
106102
grad_weight.add_(chunk_grad_weight)
@@ -123,28 +119,34 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input
123119
aggregated_metrics[i].append(metric)
124120

125121
if compiled:
126-
accumulate_chunk = torch.compile(accumulate_chunk)
122+
# TODO: Figure out what is better to compile here
123+
# accumulate_chunk = torch.compile(accumulate_chunk)
124+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
127125

128-
# Process input in chunks based on num_generations
129-
chunks = max(1, _input.shape[0] // num_generations)
126+
# Process input in chunks based on chunk_size
127+
chunks = max(1, _input.shape[0] // chunk_size)
130128
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
131129
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
132-
_rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
130+
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
133131
_ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
132+
_old_per_token_logps_chunks = torch.chunk(old_per_token_logps, chunks=chunks, dim=0) if old_per_token_logps is not None else [None] * chunks
134133

135-
for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
136-
_input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
134+
for input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk in zip(
135+
_input_chunks, _attention_mask_chunks, _advantages_chunks, _ref_input_chunks, _old_per_token_logps_chunks
137136
):
138137
# Mark dynamic dimensions
139138
torch._dynamo.mark_dynamic(input_chunk, 1)
140139
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
141-
if ref_input_chunk is not None:
140+
if use_ref_model:
142141
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
142+
else:
143+
ref_input_chunk = None
144+
if old_per_token_logps is not None:
145+
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
146+
else:
147+
old_per_token_logps_chunk = None
143148

144-
accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
145-
146-
# Scale accumulated loss by number of chunks since we're averaging
147-
loss_acc = loss_acc / chunks
149+
accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk)
148150

149151
# Combine gradients
150152
grad_input = torch.cat(grad_inputs, dim=0)
@@ -158,7 +160,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input
158160
if isinstance(metric, list):
159161
final_metrics.append(torch.cat(metric, dim=0))
160162
else:
161-
final_metrics.append(metric / chunks)
163+
final_metrics.append(metric)
162164

163165
return loss_acc, tuple(final_metrics)
164166

@@ -167,51 +169,59 @@ def _compute_chunk_loss(
167169
input_chunk,
168170
weight,
169171
attention_mask_chunk,
170-
rewards_chunk,
171-
ref_input_chunk=None,
172+
advantages_chunk,
172173
bias=None,
173-
beta=0.1,
174-
use_ref_model=False,
174+
ref_input_chunk=None,
175175
ref_weight=None,
176176
ref_bias=None,
177+
old_per_token_logps_chunk=None,
178+
full_attention_mask=None,
179+
epsilon_low=0.2,
180+
epsilon_high=0.2,
181+
beta=0.1,
182+
temperature=1.0,
183+
use_ref_model=False,
177184
rlhf_loss_fn=None,
178185
):
179186
"""Compute loss for a single chunk."""
180187
# Get policy log probabilities using chunk_forward
181-
log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
188+
log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
182189

183190
# Get reference log probabilities if needed
184191
ref_log_probs = None
185192
if use_ref_model and ref_input_chunk is not None:
186193
with torch.no_grad():
187-
ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
194+
ref_log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature)
188195

189196
# Compute chunk loss and metrics using the provided loss function
190197
chunk_loss, chunk_metrics = rlhf_loss_fn(
191198
log_probs=log_probs,
192199
attention_mask=attention_mask_chunk,
193-
rewards=rewards_chunk,
200+
advantages=advantages_chunk,
201+
full_attention_mask=full_attention_mask,
194202
ref_log_probs=ref_log_probs,
203+
old_per_token_logps=old_per_token_logps_chunk,
204+
epsilon_low=epsilon_low,
205+
epsilon_high=epsilon_high,
195206
beta=beta,
196207
)
197208

198-
return chunk_loss, (logits_mean, *chunk_metrics)
209+
return chunk_loss, chunk_metrics
199210

200211
@staticmethod
201-
def chunk_forward(input_chunk, weight, bias=None):
212+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
202213
"""Forward pass computation for a single chunk without explicit reshaping."""
203214
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
204215
logits = torch.matmul(input_chunk, weight.t())
205216
if bias is not None:
206217
logits = logits + bias # Broadcasts bias to [B, T, V]
218+
if temperature != 1.0:
219+
logits = logits / temperature
207220

208221
# Compute log probabilities using softmax over the last dimension
209222
log_probs = F.log_softmax(logits.float(), dim=-1)
210223

211-
# Monitoring: compute mean of logits
212-
batch_size, seq_len, _ = input_chunk.shape
213-
logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
214-
return log_probs, logits, logits_mean
224+
return log_probs, logits
215225

216226
@staticmethod
217227
def backward(ctx, grad_output, *grad_metrics):
@@ -227,14 +237,17 @@ def backward(ctx, grad_output, *grad_metrics):
227237
grad_input,
228238
grad_weight,
229239
None, # grad_attention_mask
230-
None, # grad_rewards
240+
None, # grad_advantages
231241
grad_bias,
232-
None, # grad_num_generations
233-
None, # grad_beta
234-
None, # grad_compiled
235-
None, # grad_use_ref_model
236242
None, # grad_ref_input
237243
None, # grad_ref_weight
238244
None, # grad_ref_bias
245+
None, # grad_old_per_token_logps
246+
None, # grad_epsilon_low
247+
None, # grad_epsilon_high
248+
None, # grad_beta
249+
None, # grad_temperature
250+
None, # grad_compiled
251+
None, # grad_use_ref_model
239252
None, # grad_chunk_size
240253
)

0 commit comments

Comments
 (0)