
Commit 96a0820

Merge branch 'main' into dev/scalable_grpo
2 parents: 6400b5d + 54c2d72


20 files changed: +695, -65 lines


docs/sphinx_doc/source/tutorial/develop_workflow.md

Lines changed: 4 additions & 3 deletions
@@ -513,13 +513,14 @@ Here, `<config_file_path>` is the path to a YAML configuration file, which shoul
 Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```

 - `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
-- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
+- `<output_dir>`: Directory to save the debug output. If not specified, the output is saved to the `debug_output` directory in the current working directory.
 - `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
+- `--enable-profiling` (optional): Enable performance profiling using [viztracer](https://github.com/gaogaotiantian/viztracer).

-During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
+During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection, and the generated experiences will be saved to the `<output_dir>/experiences.db` file.

 When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
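
As a quick sanity check after a debug run, the saved experiences can be inspected directly. The sketch below assumes `experiences.db` is an ordinary SQLite file (the wording above suggests this, but the change does not spell out its schema) and only lists table names and row counts; it is illustrative and not part of Trinity-RFT:

```python
import sqlite3

# Hypothetical path; substitute your actual <output_dir>.
db_path = "debug_output/experiences.db"

with sqlite3.connect(db_path) as conn:
    # Enumerate tables without assuming anything about their schema.
    tables = [row[0] for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    )]
    for table in tables:
        count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
        print(f"{table}: {count} rows")
```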

docs/sphinx_doc/source_zh/tutorial/develop_workflow.md

Lines changed: 6 additions & 5 deletions
@@ -509,13 +509,14 @@ trinity debug --config <config_file_path> --module inference_model
 Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

 ```bash
-trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+trinity debug --config <config_file_path> --module workflow --output-dir <output_dir> --plugin-dir <plugin_dir> --enable-profiling
 ```

-- `config_file_path`: Path to the YAML configuration file, usually the same one used to start the inference model.
-- `output_file_path`: Output path for the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for viewing in a browser.
-- `plugin_dir` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can use this parameter to load custom modules.
+- `<config_file_path>`: Path to the YAML configuration file, usually the same one used to start the inference model.
+- `<output_dir>`: Directory where the debug output is saved. If not specified, the output is saved to the `debug_output` directory in the current working directory.
+- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can use this parameter to load custom modules.
+- `--enable-profiling` (optional): Enable performance profiling; [viztracer](https://github.com/gaogaotiantian/viztracer) is used to profile the workflow execution.

-During debugging, the `buffer.explorer_input.taskset` field in the config is loaded to initialize the task dataset and instance required by the workflow. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value is automatically formatted and printed in the terminal for easy inspection.
+During debugging, the `buffer.explorer_input.taskset` field in the config is loaded to initialize the task dataset and instance required by the workflow. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value is automatically formatted and printed in the terminal for easy inspection, and the generated experiences are saved to the `<output_dir>/experiences.db` database.

 When debugging is complete, press `Ctrl+C` in the inference model's terminal to terminate it.

tests/algorithm/kl_fn_test.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""Test for KL functions"""

import unittest

import torch

from trinity.algorithm.kl_fn.kl_fn import KL_FN


class KLFnTest(unittest.TestCase):
    def setUp(self):
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        shape = (4, 10)
        self.logprob = 2 * torch.rand(shape) - 1
        self.ref_logprob = 2 * torch.rand(shape) - 1
        self.old_logprob = 2 * torch.rand(shape) - 1
        self.response_mask = torch.rand(shape) > 0.5

    def test_k1_kl_fn(self):
        kl_fn_cls = KL_FN.get("k1")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = self.logprob - self.ref_logprob
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_k2_kl_fn(self):
        kl_fn_cls = KL_FN.get("k2")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = (self.logprob - self.ref_logprob).square() * 0.5
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_k3_kl_fn(self):
        kl_fn_cls = KL_FN.get("k3")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        logr = self.ref_logprob - self.logprob
        expected_kl = logr.exp() - 1 - logr
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_abs_kl_fn(self):
        kl_fn_cls = KL_FN.get("abs")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = torch.abs(self.logprob - self.ref_logprob)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_low_var_kl_fn(self):
        kl_fn_cls = KL_FN.get("low_var_kl")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_intermediate = self.ref_logprob - self.logprob
        kl_intermediate = torch.clamp(kl_intermediate, min=-20, max=20)
        ratio = torch.exp(kl_intermediate)
        expected_kl = torch.clamp((ratio - kl_intermediate - 1).contiguous(), min=-10, max=10)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_dummy_kl_fn(self):
        kl_fn_cls = KL_FN.get("none")
        kl_fn = kl_fn_cls(kl_coef=0.01)
        kl = kl_fn.calculate_kl(self.logprob, self.ref_logprob)
        expected_kl = torch.zeros_like(self.logprob)
        self.assertTrue(torch.allclose(kl, expected_kl))

    def test_corrected_k3_fallback(self):
        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_corrected_no_old = corrected_k3_fn.calculate_kl(
            self.logprob, self.ref_logprob, old_logprob=None
        )
        self.assertTrue(torch.allclose(kl_standard, kl_corrected_no_old))

    def test_corrected_k3_with_old_logprob(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_corrected = corrected_k3_fn.calculate_kl(
            self.logprob, self.ref_logprob, self.old_logprob
        )
        logr = self.ref_logprob - self.logprob
        kl_standard = logr.exp() - 1 - logr
        log_ratio_is = self.logprob - self.old_logprob
        ratio_is = log_ratio_is.exp()
        ratio_is = torch.clamp(ratio_is, min=0.0, max=2.0)
        expected_kl = ratio_is * kl_standard
        self.assertTrue(torch.allclose(kl_corrected, expected_kl))

    def test_corrected_k3_same_policy(self):
        k3_fn = KL_FN.get("k3")(kl_coef=0.01)
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_standard = k3_fn.calculate_kl(self.logprob, self.ref_logprob)
        kl_corrected = corrected_k3_fn.calculate_kl(self.logprob, self.ref_logprob, self.logprob)
        self.assertTrue(torch.allclose(kl_standard, kl_corrected, rtol=1e-4, atol=1e-6))

    def test_corrected_k3_loss(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_loss, metrics = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="token-mean",
            old_logprob=self.old_logprob,
        )
        self.assertEqual(kl_loss.dim(), 0)
        self.assertIn("kl_loss", metrics)
        self.assertIn("kl_coef", metrics)
        self.assertEqual(metrics["kl_coef"], 0.01)

    def test_kl_loss_aggregation_modes(self):
        corrected_k3_fn = KL_FN.get("corrected_k3")(kl_coef=0.01)
        kl_loss_mean, _ = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="token-mean",
            old_logprob=self.old_logprob,
        )
        kl_loss_sum, _ = corrected_k3_fn.calculate_kl_loss(
            logprob=self.logprob,
            ref_logprob=self.ref_logprob,
            response_mask=self.response_mask,
            loss_agg_mode="seq-mean-token-sum",
            old_logprob=self.old_logprob,
        )
        self.assertGreater(kl_loss_sum.item(), kl_loss_mean.item())
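
Taken together, the assertions above pin down the estimator that each registry key is expected to implement. Writing the per-token log-ratio as ell = log pi_ref(x) - log pi(x), the expected values are (as encoded by the tests, not quoted from the library source):

```latex
\begin{aligned}
\text{k1:}\quad            & \log\pi(x) - \log\pi_{\mathrm{ref}}(x) = -\ell \\
\text{k2:}\quad            & \tfrac{1}{2}\,\ell^{2} \\
\text{k3:}\quad            & e^{\ell} - 1 - \ell \\
\text{abs:}\quad           & \lvert \ell \rvert \\
\text{low\_var\_kl:}\quad  & \operatorname{clip}\!\bigl(e^{\bar\ell} - \bar\ell - 1,\,-10,\,10\bigr), \qquad \bar\ell = \operatorname{clip}(\ell,\,-20,\,20) \\
\text{corrected\_k3:}\quad & \operatorname{clip}\!\bigl(e^{\log\pi(x) - \log\pi_{\mathrm{old}}(x)},\,0,\,2\bigr)\cdot\bigl(e^{\ell} - 1 - \ell\bigr)
\end{aligned}
```

with `corrected_k3` falling back to plain `k3` when no `old_logprob` is provided, and `none` returning zeros.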

tests/algorithm/policy_loss_test.py

Lines changed: 17 additions & 0 deletions
@@ -142,3 +142,20 @@ def test_ppo_policy_loss_with_sequence_masking(self):
         self.assertTrue(
             torch.allclose(torch.tensor(metrics["seq_mask/mean_sequence_kl"]), mean_sequence_kl)
         )
+
+    def test_sapo_policy_loss(self):
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("sapo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+        sapo_loss = torch.tensor(-0.05128994956612587)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        avg_soft_gate = torch.tensor(2.3191137313842773)
+        avg_ratio = torch.tensor(1.630766749382019)
+        pos_adv_frac = torch.tensor(0.3958333432674408)
+        self.assertTrue(torch.allclose(loss, sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["sapo_loss"]), sapo_loss))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_soft_gate"]), avg_soft_gate))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["avg_ratio"]), avg_ratio))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pos_adv_frac"]), pos_adv_frac))

tests/cli/launcher_test.py

Lines changed: 74 additions & 25 deletions
@@ -41,6 +41,7 @@ def setUp(self):

     def tearDown(self):
         sys.argv = self._orig_argv
+        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)

     @mock.patch("trinity.cli.launcher.serve")
     @mock.patch("trinity.cli.launcher.explore")
@@ -254,31 +255,79 @@ def test_multi_stage_run(
     @mock.patch("trinity.cli.launcher.load_config")
     def test_debug_mode(self, mock_load):
         process = multiprocessing.Process(target=debug_inference_model_process)
-        process.start()
-        time.sleep(15)  # wait for the model to be created
-        for _ in range(10):
-            try:
-                get_debug_inference_model(self.config)
-                break
-            except Exception:
-                time.sleep(3)
-        output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
-        self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
-        mock_load.return_value = self.config
-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
-        process.join(timeout=10)
-        process.terminate()
-        self.assertTrue(os.path.exists(output_file))
+        try:
+            process.start()
+            time.sleep(15)  # wait for the model to be created
+            for _ in range(10):
+                try:
+                    get_debug_inference_model(self.config)
+                    break
+                except Exception:
+                    time.sleep(3)
+            output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
+            output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
+            self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
+            mock_load.return_value = self.config
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=True,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))
+            # add a dummy file to test overwrite behavior
+            with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
+                f.write("not empty")
+
+            with mock.patch(
+                "argparse.ArgumentParser.parse_args",
+                return_value=mock.Mock(
+                    command="debug",
+                    config="dummy.yaml",
+                    module="workflow",
+                    enable_profiling=False,
+                    output_dir=output_dir,
+                    output_file=output_file,
+                    plugin_dir="",
+                ),
+            ):
+                launcher.main()
+
+            self.assertFalse(os.path.exists(output_file))
+            # test the original files are not overwritten
+            self.assertTrue(os.path.exists(output_dir))
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "dummy.txt")))
+            dirs = os.listdir(self.config.checkpoint_job_dir)
+            target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
+            self.assertEqual(len(target_output_dir), 1)
+            self.assertFalse(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "profiling.html"
+                    )
+                )
+            )
+            self.assertTrue(
+                os.path.exists(
+                    os.path.join(
+                        self.config.checkpoint_job_dir, target_output_dir[0], "experiences.db"
+                    )
+                )
+            )
+        finally:
+            process.join(timeout=10)
+            process.terminate()


 def debug_inference_model_process():
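
The second half of the test pins down the collision behavior: when `--output-dir` already exists and is not empty, the debug run must leave it untouched and write into a fresh `debug_output_*` sibling instead. A minimal sketch of that selection logic, purely illustrative and not the actual launcher code (the real suffix scheme may differ), could look like:

```python
import os
import time


def pick_debug_output_dir(requested: str) -> str:
    """Return `requested` if it is missing or empty; otherwise a fresh sibling."""
    if not os.path.exists(requested) or not os.listdir(requested):
        os.makedirs(requested, exist_ok=True)
        return requested
    # Existing, non-empty directory: never overwrite, create a suffixed sibling.
    fresh = f"{requested.rstrip(os.sep)}_{time.strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(fresh, exist_ok=True)
    return fresh
```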

tests/utils/plugins/dependencies.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
"""A file contains some dependencies."""

DEPENDENCY_VALUE = 0


def dependency_func():
    return "0"

tests/utils/plugins/main.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
from tests.utils.plugins.dependencies import DEPENDENCY_VALUE, dependency_func

from trinity.common.workflows.workflow import Workflow


class MainDummyWorkflow(Workflow):
    def __init__(self, *, task, model, auxiliary_models=None):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

    @property
    def repeatable(self):
        return True

    def set_repeat_times(self, repeat_times, run_id_base):
        pass

    def run(self) -> list:
        return [DEPENDENCY_VALUE, dependency_func()]

tests/utils/registry_test.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import unittest

import ray


class ImportUtils:
    def run(self):
        from trinity.common.workflows import WORKFLOWS, Workflow

        workflow_cls = WORKFLOWS.get("tests.utils.plugins.main.MainDummyWorkflow")
        assert issubclass(workflow_cls, Workflow)
        workflow = workflow_cls(task=None, model=None)
        res = workflow.run()
        assert res[0] == 0
        assert res[1] == "0"


class TestRegistry(unittest.TestCase):
    def setUp(self):
        ray.init(ignore_reinit_error=True)

    def tearDown(self):
        ray.shutdown()

    def test_dynamic_import(self):
        # test local import
        ImportUtils().run()
        # test remote import
        ray.get(ray.remote(ImportUtils).remote().run.remote())
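
The lookup key here is a fully qualified dotted path rather than a registered short name, and it resolves both in the local process and inside a Ray actor, so the registry must import the module on demand wherever the lookup runs. Conceptually the mechanism is dotted-path dynamic import; the sketch below illustrates the idea and is not the actual `WORKFLOWS.get` implementation (which may add caching, plugin-directory search, and error handling):

```python
import importlib


def resolve_dotted_path(path: str):
    """Import `pkg.module.ClassName` and return the class object."""
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Example: resolve_dotted_path("tests.utils.plugins.main.MainDummyWorkflow")
```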

trinity/algorithm/algorithm.py

Lines changed: 27 additions & 0 deletions
@@ -250,6 +250,33 @@ def default_config(cls) -> Dict:
         }


+@ALGORITHM_TYPE.register_module("sapo")
+class SAPOAlgorithm(AlgorithmType):
+    """SAPO (Soft Adaptive Policy Optimization) algorithm.
+
+    SAPO uses a smooth, temperature-controlled soft gate instead of hard clipping
+    to stabilize training while maintaining effective learning.
+    """
+
+    use_critic: bool = False
+    use_reference: bool = True
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "advantage_fn": "grpo",
+            "sample_strategy": "default",
+            "policy_loss_fn": "sapo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
+        }
+
+
 @ALGORITHM_TYPE.register_module("mix")
 class MIXAlgorithm(AlgorithmType):
     """MIX algorithm."""
