Commit faed5ef

Merge branch 'main' of github.com:modelscope/Trinity-RFT into fix/rename_warmup_style

2 parents d67a913 + 530c877

8 files changed: +102 −7 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob
 
 ## 🚀 News
 
+* [2026-01] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 released: upgraded verl to v0.7.0, the Tinker backend now supports the OpenAI API, and several bugs were fixed.
 * [2026-01] Introducing [R3L](https://github.com/shiweijiezero/R3L): a systematic reflect-then-retry RL mechanism with efficient language-guided exploration and stable off-policy learning ([paper](https://arxiv.org/abs/2601.03715)).
 * [2025-12] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 released: added the [Tinker](https://thinkingmachines.ai/tinker/) backend for users **without GPUs**, added more benchmarks, enhanced online RL, and more.
 * [2025-12] Trinity-RFT powers the medical and health business of "Taobao Shangou", enabling the AI agent to understand vague symptoms, proactively ask follow-up questions, and provide precise recommendations ([News](https://tech.china.com.cn/sx/20251201/411376.shtml)).

README_zh.md

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ Trinity-RFT provides corresponding features for users with different backgrounds and goals:
 
 ## 🚀 News
 
+* [2026-01] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 released: upgraded verl to v0.7.0, the Tinker backend now supports the OpenAI API, and several bugs were fixed.
 * [2026-01] Introducing [R3L](https://github.com/shiweijiezero/R3L): a reflect-then-retry reinforcement learning mechanism that uses natural-language feedback to guide efficient exploration and achieves stable off-policy learning ([paper](https://arxiv.org/abs/2601.03715)).
 * [2025-12] [[Release Notes]](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 released: added the [Tinker](https://thinkingmachines.ai/tinker/) backend to support training on devices **without GPUs**, added more benchmarks, enhanced online RL, and more.
 * [2025-12] Trinity-RFT now supports the [tinker](https://thinkingmachines.ai/tinker/) training backend, enabling model training on devices **without GPUs**.

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 2 deletions

@@ -111,8 +111,8 @@ algorithm:
 - `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`. Some algorithms such as GRPO and OPMD require `repeat_times` > 1.
 - `optimizer`: Optimizer configuration for actor.
   - `lr`: Learning rate for actor.
-  - `warmup_style`: Deprecated, use `lr_scheduler_type` instead.
-  - `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `consine`.
+  - `warmup_style`: Deprecated, use `lr_scheduler_type` instead. We will remove this field in future versions.
+  - `lr_scheduler_type`: Learning rate scheduler type for actor model. Default is `constant`. Supported types: `constant`, `cosine`.
 - `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The advantage function used for computing advantages.
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
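
To make the renamed field concrete, here is a minimal sketch of how these options might be combined in a config file. The nesting mirrors the bullet structure documented above; the specific values (`8`, `1.0e-6`, `cosine`) are illustrative placeholders, not recommendations from this commit.

```yaml
algorithm:
  repeat_times: 8              # GRPO and OPMD require repeat_times > 1
  optimizer:
    lr: 1.0e-6                 # actor learning rate
    lr_scheduler_type: cosine  # replaces the deprecated warmup_style; default is constant
  sample_strategy: default     # or staleness_control / mix
```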

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ algorithm:
 - `repeat_times`: Number of times each task is repeated. Default is `1`. Automatically set to `2` in `dpo`. Some algorithms such as GRPO and OPMD require `repeat_times` > 1.
 - `optimizer`: Parameters of the actor optimizer.
   - `lr`: Learning rate of the optimizer.
-  - `warmup_style`: Deprecated, use `lr_scheduler_type` instead.
+  - `warmup_style`: Deprecated, use `lr_scheduler_type` instead. This field will be removed in a future version.
   - `lr_scheduler_type`: Learning rate scheduler type for the actor model. Default is `constant`. Supported types: `constant`, `cosine`.
 - `sample_strategy`: The sampling strategy used when loading experiences from the experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The function used for computing advantage values.

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "trinity-rft"
-version = "0.4.0"
+version = "0.4.1"
 authors = [
     {name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
 ]
@@ -88,7 +88,7 @@ tinker = [
 ]
 
 doc = [
-    "sphinx",
+    "sphinx<9.0.0",
     "sphinx-autobuild",
     "sphinx-book-theme",
     "myst-parser",

tests/common/vllm_test.py

Lines changed: 65 additions & 1 deletion

@@ -1270,7 +1270,6 @@ def setUp(self):
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.explorer.rollout_model.enable_lora = True
-        self.config.explorer.rollout_model.enable_runtime_lora_updating = True
 
         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -1345,3 +1344,68 @@ async def test_tinker_api(self):
         self.assertEqual(response.sequences[0].stop_reason, "length")
         self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
         self.assertIsNone(response.topk_prompt_logprobs)
+
+        # test adding and removing LoRA adapters
+        from vllm.lora.request import LoRARequest
+
+        # create dummy LoRA adapters whose weight deltas start at zero
+        lora_path_1 = os.path.join(self.config.checkpoint_job_dir, "adapter_1")
+        lora_path_2 = os.path.join(self.config.checkpoint_job_dir, "adapter_2")
+        _create_adapter(self.config.model.model_path, lora_path_1, "adapter_1")
+        _create_adapter(self.config.model.model_path, lora_path_2, "adapter_2")
+        lora_1 = LoRARequest(
+            lora_name="test_adapter_1",
+            lora_int_id=1,
+            lora_path=os.path.join(lora_path_1, "adapter_1"),
+        )
+        lora_2 = LoRARequest(
+            lora_name="test_adapter_2",
+            lora_int_id=2,
+            lora_path=os.path.join(lora_path_2, "adapter_2"),
+        )
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_1,
+        )
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [1])
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_2,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        await engine.remove_lora_adapter.remote(lora_id=1)
+        await engine.remove_lora_adapter.remote(lora_id=2)
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [])
+
+
+def _create_adapter(model_path: str, lora_path: str, name: str):
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map="cpu",
+    )
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=8,
+        target_modules=["gate_proj", "up_proj", "down_proj"],
+        lora_dropout=0.1,
+    )
+    lora_model = get_peft_model(model, lora_config, adapter_name=name)
+    lora_model.save_pretrained(lora_path)
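
A note on the adapter paths in the test above: each `LoRARequest` points at `os.path.join(lora_path_N, "adapter_N")` rather than at `lora_path_N` itself, which appears to be because `save_pretrained` on a PEFT model with a non-default adapter name writes that adapter's weights into a subdirectory named after the adapter.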

trinity/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Trinity-RFT (Reinforcement Fine-Tuning)"""
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"

trinity/common/models/vllm_model.py

Lines changed: 29 additions & 0 deletions

@@ -403,6 +403,35 @@ async def logprobs( # type: ignore [override]
             dtype=torch.float32,
         )
 
+    async def add_lora_adapter(self, lora_request: Any) -> int:
+        """Add a LoRA adapter to the vLLM engine.
+
+        Args:
+            lora_request (LoRARequest): The LoRA request.
+
+        Returns:
+            lora_id (int): The LoRA adapter ID.
+        """
+        lora_id = await self.async_llm.add_lora(lora_request)
+        return lora_id
+
+    async def remove_lora_adapter(self, lora_id: int) -> None:
+        """Remove a LoRA adapter from the vLLM engine.
+
+        Args:
+            lora_id (int): The LoRA adapter ID.
+        """
+        await self.async_llm.remove_lora(lora_id)
+
+    async def list_lora_adapters(self) -> Sequence[int]:
+        """List all LoRA adapter IDs in the vLLM engine.
+
+        Returns:
+            lora_ids (List[int]): The list of LoRA adapter IDs.
+        """
+        lora_ids = await self.async_llm.list_loras()
+        return list(lora_ids)
+
     async def sample(
         self,
         prompt: Any,
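
As a usage sketch of the three new methods, assuming an `engine` Ray actor handle and a `prompt`/`types` setup like the ones in `tests/common/vllm_test.py` above (the adapter path and names here are placeholders, not values from the commit):

```python
# Minimal sketch of runtime LoRA swapping, mirroring tests/common/vllm_test.py.
from vllm.lora.request import LoRARequest


async def sample_with_adapter(engine, prompt, types):
    lora = LoRARequest(
        lora_name="demo_adapter",
        lora_int_id=1,
        lora_path="/path/to/saved/adapter",  # placeholder path
    )
    # Sampling with a lora_request registers the adapter in the engine.
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=1,
        sampling_params=types.SamplingParams(max_tokens=8),
        lora_request=lora,
    )
    # The adapter now shows up in the engine's registry...
    assert await engine.list_lora_adapters.remote() == [1]
    # ...and can be unloaded once it is no longer needed.
    await engine.remove_lora_adapter.remote(lora_id=1)
    return response
```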
