
Commit 17cd39b

add DAPO
1 parent 2693f1f commit 17cd39b

3 files changed

+79 -10 lines changed

apps/julia-grpo/main.py

Lines changed: 70 additions & 3 deletions
@@ -215,6 +215,75 @@ def simple_grpo_loss(
     return loss
 
 
+def dapo_loss(
+    logits: torch.Tensor,
+    response: torch.Tensor,
+    ref_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    padding_mask: torch.Tensor,
+    beta: float = 0.005,
+    clip_eps_low: float = 0.2,
+    clip_eps_high: float = 0.28,
+) -> torch.Tensor:
+    """
+    DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) loss function.
+
+    Implements a PPO-style clipped objective with asymmetric clip bounds and a KL divergence penalty.
+    Based on the compute_loss function from old_dapo.py.
+
+    Args:
+        logits: Model output logits [batch_size, seq_len, vocab_size]
+        response: Response token ids [batch_size, seq_len]
+        ref_logprobs: Reference model log probabilities [batch_size, seq_len]
+        advantages: Advantage values [batch_size, 1]
+        padding_mask: Mask for valid tokens [batch_size, seq_len]
+        beta: KL divergence coefficient
+        clip_eps_low: Lower clipping bound for the importance sampling ratio
+        clip_eps_high: Upper clipping bound for the importance sampling ratio
+
+    Returns:
+        Scalar loss value
+    """
+    # Compute current per-token log probabilities of the response tokens
+    action_log_probs = compute_logprobs(logits, response)
+
+    # Compute the KL divergence term (k3 estimator)
+    if beta != 0.0:
+        log_ratio = ref_logprobs - action_log_probs
+        log_ratio = log_ratio * padding_mask
+        k3 = log_ratio.exp() - 1 - log_ratio
+
+    # Use detached log probs as the "old" log probs (single policy iteration);
+    # in a multi-iteration setting these would be passed in as an argument
+    old_action_log_probs = action_log_probs.detach()
+
+    # Compute the importance sampling ratio
+    coef_1 = torch.exp(action_log_probs - old_action_log_probs)
+
+    # Clipped importance sampling ratio (asymmetric: clip_eps_high > clip_eps_low)
+    coef_2 = torch.clamp(coef_1, 1 - clip_eps_low, 1 + clip_eps_high)
+
+    # Compute per-token losses with advantages;
+    # advantages have shape [batch_size, 1] and broadcast over seq_len
+    per_token_loss1 = coef_1 * advantages
+    per_token_loss2 = coef_2 * advantages
+
+    # Take the minimum for the clipped objective (negated because we minimize)
+    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    # Apply the padding mask
+    per_token_loss = per_token_loss * padding_mask
+
+    # Add the KL penalty
+    if beta != 0.0:
+        per_token_loss = per_token_loss + beta * k3
+
+    # Average over tokens, then over the batch
+    loss = (per_token_loss.sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)).mean()
+
+    return loss
+
+
 @dataclass
 class JuliaRewardActor(ForgeActor):
     """Reward actor for Julia code execution using GenericOpenEnvActor.
@@ -550,9 +619,7 @@ async def main(cfg: DictConfig):
     ) = await asyncio.gather(
         JuliaDatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        RLTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
-        ),
+        RLTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer, loss=dapo_loss),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
         ),
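
With this change the trainer picks up dapo_loss's default clip bounds (clip_eps_low=0.2, clip_eps_high=0.28). A toy sketch, independent of the repo code, of what the asymmetric ("clip-higher") bounds do for positive advantages:

# Toy illustration (not repo code): the ratio may grow to 1 + clip_eps_high = 1.28
# before a positive-advantage update is capped, while the lower bound stays at 0.8.
import torch

ratio = torch.tensor([0.5, 1.0, 1.3])        # pi_new / pi_old for three tokens
advantage = torch.tensor([1.0, 1.0, 1.0])    # all-positive advantages
clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.28)
per_token = -torch.min(ratio * advantage, clipped * advantage)
print(clipped)    # tensor([0.8000, 1.0000, 1.2800])
print(per_token)  # tensor([-0.5000, -1.0000, -1.2800])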

apps/julia-grpo/old_dapo.py

Lines changed: 8 additions & 6 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -388,9 +394,7 @@ def compute_loss(self, model, inputs):
         coef_2 = torch.clamp(
             coef_1, 1 - self.args.clip_eps_low, 1 + self.args.clip_eps_high
         )
-        per_token_loss1 = coef_1 * advantages.unsqueeze(
-            1
-        )  # every token in a sequence shares the same advantage
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)  # every token in a sequence shares the same advantage
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(
             per_token_loss1, per_token_loss2
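
A minimal broadcast sketch (toy shapes, not repo code) of what the comment on the changed line describes: advantages.unsqueeze(1) turns a per-sequence advantage of shape [B] into [B, 1], which broadcasts across the [B, T] ratio so every token in a sequence receives the same advantage.

# Toy broadcast check (not repo code).
import torch

B, T = 2, 4
coef_1 = torch.ones(B, T)                  # per-token importance ratios
advantages = torch.tensor([0.5, -1.0])     # one advantage per sequence, shape [B]
per_token_loss1 = coef_1 * advantages.unsqueeze(1)  # shape [B, T]
print(per_token_loss1)
# tensor([[ 0.5000,  0.5000,  0.5000,  0.5000],
#         [-1.0000, -1.0000, -1.0000, -1.0000]])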
@@ -518,9 +522,7 @@ def train(self):
             if self.update_steps % 10 == 0:
                 print(f"\nStep {self.update_steps}: === Starting model evaluation ===")
                 accuracy = self.evaluate(num_samples=100, batch_size=25)
-                print(
-                    f"Step {self.update_steps}: model accuracy: {accuracy:.2f}"
-                )
+                print(f"Step {self.update_steps}: model accuracy: {accuracy:.2f}")
                 # Save accuracy to a file
                 accuracy_file_path = os.path.join(
                     self.args.output_dir, "accuracy_losses.txt"

src/forge/actors/generic_openenv.py

Lines changed: 1 addition & 1 deletion
@@ -12,10 +12,10 @@
 from core.client_types import StepResult
 from core.env_server.types import Action, Observation
 from core.http_env_client import HTTPEnvClient
+from monarch.actor import endpoint
 
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
-from monarch.actor import endpoint
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)