22from functools import partial
33
44import torch
5+ import torch ._dynamo .config
56import torch .nn .functional as F
67
78
@@ -20,15 +21,18 @@ def forward(
2021 _input ,
2122 weight ,
2223 attention_mask ,
23- rewards ,
24+ advantages ,
2425 bias = None ,
25- num_generations = 4 ,
26- beta = 0.1 ,
27- compiled = True ,
28- use_ref_model = False ,
2926 ref_input = None ,
3027 ref_weight = None ,
3128 ref_bias = None ,
29+ old_per_token_logps = None ,
30+ epsilon_low = 0.2 ,
31+ epsilon_high = 0.2 ,
32+ beta = 0.1 ,
33+ temperature = 1.0 ,
34+ compiled = True ,
35+ use_ref_model = False ,
3236 chunk_size = 1 ,
3337 ):
3438 """Chunked forward pass for RLHF loss computation.
@@ -39,21 +43,20 @@ def forward(
3943 _input: Input tensor
4044 weight: Weight tensor
4145 attention_mask: Attention mask tensor
42- rewards: Rewards tensor
46+ advantages: Advantages tensor
4347 bias: Bias tensor
44- num_generations: Number of generations per prompt
45- beta: Weight for the KL penalty
46- compiled: Whether to use torch compile
47- use_ref_model: Whether to use a reference model
4848 ref_input: Reference model input tensor
4949 ref_weight: Reference model weight tensor
5050 ref_bias: Reference model bias tensor
51+ old_per_token_logps: Old per token log probabilities tensor
52+ epsilon_low: Lower bound for clipping the importance sampling ratio
53+ epsilon_high: Upper bound for clipping the importance sampling ratio
54+ beta: Weight for the KL penalty
55+ temperature: Temperature for the logits
56+ compiled: Whether to use torch compile
57+ use_ref_model: Whether to use a reference model
5158 chunk_size: Size of chunks for processing in other loss modules
5259 """
53- # Save for backward
54- ctx .beta = beta
55- ctx .rewards = rewards
56-
5760 # Initialize accumulators
5861 loss_acc = torch .zeros ((), device = _input .device )
5962 grad_weight = torch .zeros_like (weight ) # [V, H]
@@ -64,43 +67,36 @@ def forward(
6467 # Create a partial function with fixed arguments
6568 compute_loss = partial (
6669 LigerFusedLinearRLHFBase ._compute_chunk_loss ,
67- beta = beta ,
68- use_ref_model = use_ref_model ,
6970 ref_weight = ref_weight ,
7071 ref_bias = ref_bias ,
72+ full_attention_mask = attention_mask ,
73+ epsilon_low = epsilon_low ,
74+ epsilon_high = epsilon_high ,
75+ beta = beta ,
76+ temperature = temperature ,
77+ use_ref_model = use_ref_model ,
7178 rlhf_loss_fn = cls .rlhf_loss_fn ,
7279 )
7380
74- def fused_fwd_bwd (input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk ):
81+ def fused_fwd_bwd (input_chunk , attention_mask_chunk , advantages_chunk , ref_input_chunk , old_per_token_logps_chunk ):
7582 """Fused forward and backward for a chunk."""
83+ argnums = (0 , 1 , 4 ) if bias is not None else (0 , 1 )
84+ return torch .func .grad_and_value (compute_loss , argnums = argnums , has_aux = True )(
85+ input_chunk , # arg 0
86+ weight , # arg 1
87+ attention_mask_chunk , # arg 2
88+ advantages_chunk , # arg 3
89+ bias , # arg 4
90+ ref_input_chunk = ref_input_chunk , # arg 5
91+ old_per_token_logps_chunk = old_per_token_logps_chunk , # arg 6
92+ )
93+
94+ def accumulate_chunk (input_chunk , attention_mask_chunk , advantages_chunk , ref_input_chunk = None , old_per_token_logps_chunk = None ):
95+ (chunk_grad_input , chunk_grad_weight , * chunk_grad_bias ), (chunk_loss , chunk_metrics ) = fused_fwd_bwd (
96+ input_chunk , attention_mask_chunk , advantages_chunk , ref_input_chunk , old_per_token_logps_chunk
97+ )
7698 if bias is not None :
77- return torch .func .grad_and_value (compute_loss , argnums = (0 , 1 , 5 ), has_aux = True )(
78- input_chunk , # arg 0
79- weight , # arg 1
80- attention_mask_chunk , # arg 2
81- rewards_chunk , # arg 3
82- ref_input_chunk , # arg 4
83- bias , # arg 5
84- )
85- else :
86- return torch .func .grad_and_value (compute_loss , argnums = (0 , 1 ), has_aux = True )(
87- input_chunk , # arg 0
88- weight , # arg 1
89- attention_mask_chunk , # arg 2
90- rewards_chunk , # arg 3
91- ref_input_chunk , # arg 4
92- )
93-
94- def accumulate_chunk (input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk = None ):
95- if bias is not None :
96- (chunk_grad_input , chunk_grad_weight , chunk_grad_bias ), (chunk_loss , chunk_metrics ) = fused_fwd_bwd (
97- input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk
98- )
99- grad_bias .add_ (chunk_grad_bias )
100- else :
101- (chunk_grad_input , chunk_grad_weight ), (chunk_loss , chunk_metrics ) = fused_fwd_bwd (
102- input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk
103- )
99+ grad_bias .add_ (chunk_grad_bias [0 ])
104100
105101 # Accumulate gradients and loss
106102 grad_weight .add_ (chunk_grad_weight )
@@ -123,28 +119,34 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input
123119 aggregated_metrics [i ].append (metric )
124120
125121 if compiled :
126- accumulate_chunk = torch .compile (accumulate_chunk )
122+ # TODO: Figure out what is better to compile here
123+ # accumulate_chunk = torch.compile(accumulate_chunk)
124+ fused_fwd_bwd = torch .compile (fused_fwd_bwd )
127125
128- # Process input in chunks based on num_generations
129- chunks = max (1 , _input .shape [0 ] // num_generations )
126+ # Process input in chunks based on chunk_size
127+ chunks = max (1 , _input .shape [0 ] // chunk_size )
130128 _input_chunks = torch .chunk (_input , chunks = chunks , dim = 0 )
131129 _attention_mask_chunks = torch .chunk (attention_mask , chunks = chunks , dim = 0 )
132- _rewards_chunks = torch .chunk (rewards , chunks = chunks , dim = 0 )
130+ _advantages_chunks = torch .chunk (advantages , chunks = chunks , dim = 0 )
133131 _ref_input_chunks = torch .chunk (ref_input , chunks = chunks , dim = 0 ) if use_ref_model else [None ] * chunks
132+ _old_per_token_logps_chunks = torch .chunk (old_per_token_logps , chunks = chunks , dim = 0 ) if old_per_token_logps is not None else [None ] * chunks
134133
135- for input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk in zip (
136- _input_chunks , _attention_mask_chunks , _rewards_chunks , _ref_input_chunks
134+ for input_chunk , attention_mask_chunk , advantages_chunk , ref_input_chunk , old_per_token_logps_chunk in zip (
135+ _input_chunks , _attention_mask_chunks , _advantages_chunks , _ref_input_chunks , _old_per_token_logps_chunks
137136 ):
138137 # Mark dynamic dimensions
139138 torch ._dynamo .mark_dynamic (input_chunk , 1 )
140139 torch ._dynamo .mark_dynamic (attention_mask_chunk , 1 )
141- if ref_input_chunk is not None :
140+ if use_ref_model :
142141 torch ._dynamo .mark_dynamic (ref_input_chunk , 1 )
142+ else :
143+ ref_input_chunk = None
144+ if old_per_token_logps is not None :
145+ torch ._dynamo .mark_dynamic (old_per_token_logps_chunk , 1 )
146+ else :
147+ old_per_token_logps_chunk = None
143148
144- accumulate_chunk (input_chunk , attention_mask_chunk , rewards_chunk , ref_input_chunk )
145-
146- # Scale accumulated loss by number of chunks since we're averaging
147- loss_acc = loss_acc / chunks
149+ accumulate_chunk (input_chunk , attention_mask_chunk , advantages_chunk , ref_input_chunk , old_per_token_logps_chunk )
148150
149151 # Combine gradients
150152 grad_input = torch .cat (grad_inputs , dim = 0 )
@@ -158,7 +160,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input
158160 if isinstance (metric , list ):
159161 final_metrics .append (torch .cat (metric , dim = 0 ))
160162 else :
161- final_metrics .append (metric / chunks )
163+ final_metrics .append (metric )
162164
163165 return loss_acc , tuple (final_metrics )
164166
@staticmethod
def _compute_chunk_loss(
    input_chunk,
    weight,
    attention_mask_chunk,
    advantages_chunk,
    bias=None,
    ref_input_chunk=None,
    ref_weight=None,
    ref_bias=None,
    old_per_token_logps_chunk=None,
    full_attention_mask=None,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.1,
    temperature=1.0,
    use_ref_model=False,
    rlhf_loss_fn=None,
):
    """Compute the RLHF loss and metrics for one chunk of the batch.

    Runs the policy head (and, optionally, a frozen reference head) over the
    chunk, then delegates the actual loss computation to ``rlhf_loss_fn``.

    Args:
        input_chunk: Hidden-state chunk fed to the policy LM head.
        weight: Policy LM-head weight.
        attention_mask_chunk: Attention mask for this chunk.
        advantages_chunk: Advantages for this chunk.
        bias: Optional policy LM-head bias.
        ref_input_chunk: Optional hidden-state chunk for the reference model.
        ref_weight: Reference LM-head weight.
        ref_bias: Optional reference LM-head bias.
        old_per_token_logps_chunk: Old per-token log-probs for this chunk.
        full_attention_mask: Attention mask for the full (unchunked) batch.
        epsilon_low: Lower clipping bound for the importance-sampling ratio.
        epsilon_high: Upper clipping bound for the importance-sampling ratio.
        beta: Weight of the KL penalty.
        temperature: Temperature applied to the logits.
        use_ref_model: Whether to compute reference log-probs at all.
        rlhf_loss_fn: Callable implementing the concrete RLHF loss.

    Returns:
        Tuple of (chunk_loss, chunk_metrics) as produced by ``rlhf_loss_fn``.
    """
    # Policy log-probabilities for this chunk (logits are discarded here).
    log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(
        input_chunk, weight, bias=bias, temperature=temperature
    )

    # Reference log-probabilities, computed without autograd when requested.
    ref_log_probs = None
    if use_ref_model and ref_input_chunk is not None:
        with torch.no_grad():
            ref_log_probs, _ = LigerFusedLinearRLHFBase.chunk_forward(
                ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
            )

    # Delegate to the concrete loss implementation supplied by the subclass.
    chunk_loss, chunk_metrics = rlhf_loss_fn(
        log_probs=log_probs,
        attention_mask=attention_mask_chunk,
        advantages=advantages_chunk,
        full_attention_mask=full_attention_mask,
        ref_log_probs=ref_log_probs,
        old_per_token_logps=old_per_token_logps_chunk,
        epsilon_low=epsilon_low,
        epsilon_high=epsilon_high,
        beta=beta,
    )
    return chunk_loss, chunk_metrics
199210
200211 @staticmethod
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
    """Project a hidden-state chunk through the LM head without reshaping.

    Args:
        input_chunk: Hidden states of shape [B, T, H].
        weight: LM-head weight of shape [V, H].
        bias: Optional bias, broadcast over the logits to [B, T, V].
        temperature: Softmax temperature; logits are divided by it when != 1.

    Returns:
        Tuple of (log_probs, logits) where log_probs is the float32
        log-softmax of the (temperature-scaled) logits over the vocab dim.
    """
    # Batched matmul: [B, T, H] @ [H, V] -> [B, T, V]
    raw_logits = torch.matmul(input_chunk, weight.t())
    if bias is not None:
        raw_logits = raw_logits + bias  # bias broadcasts to [B, T, V]
    if temperature != 1.0:
        raw_logits = raw_logits / temperature

    # log-softmax in float32 over the vocabulary dimension
    token_log_probs = F.log_softmax(raw_logits.float(), dim=-1)
    return token_log_probs, raw_logits
215225
216226 @staticmethod
217227 def backward (ctx , grad_output , * grad_metrics ):
@@ -227,14 +237,17 @@ def backward(ctx, grad_output, *grad_metrics):
227237 grad_input ,
228238 grad_weight ,
229239 None , # grad_attention_mask
230- None , # grad_rewards
240+ None , # grad_advantages
231241 grad_bias ,
232- None , # grad_num_generations
233- None , # grad_beta
234- None , # grad_compiled
235- None , # grad_use_ref_model
236242 None , # grad_ref_input
237243 None , # grad_ref_weight
238244 None , # grad_ref_bias
245+ None , # grad_old_per_token_logps
246+ None , # grad_epsilon_low
247+ None , # grad_epsilon_high
248+ None , # grad_beta
249+ None , # grad_temperature
250+ None , # grad_compiled
251+ None , # grad_use_ref_model
239252 None , # grad_chunk_size
240253 )
0 commit comments