Skip to content

Commit 9634f6a

Browse files
committed
Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/config_manager
2 parents 1b03d27 + a9c650b commit 9634f6a

File tree

7 files changed

+46
-11
lines changed

7 files changed

+46
-11
lines changed

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
data:
22
# basic info
3-
dataset_path: '/PATH/TO/DATASET/'
3+
dataset_path: 'openai/gsm8k'
4+
subset_name: "main"
45
train_split: 'train'
5-
eval_split: ''
6+
eval_split: 'test'
67
format_config:
78
prompt_key: 'question'
89
response_key: 'answer'
@@ -24,7 +25,7 @@ model:
2425
model_path: '/PATH/TO/MODEL/'
2526
max_prompt_tokens: 256
2627
max_response_tokens: 1024
27-
checkpoint_path: '/PATH/TO/CHECKPOINT/'
28+
checkpoint_path: ""
2829
cluster:
2930
node_num: 1
3031
gpu_per_node: 8
@@ -34,7 +35,8 @@ buffer:
3435
train_dataset:
3536
name: gsm8k_buffer
3637
storage_type: queue
37-
path: 'sqlite:////gsm8k.db'
38+
algorithm_type: ppo
39+
path: 'sqlite:///gsm8k.db'
3840
# sft_warmup_dataset: # Uncomment these to enable sft warmup
3941
# name: warmup_data
4042
# storage_type: file

trinity/cli/launcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def run(config_path: str):
129129
data_config.dj_config_path or data_config.dj_process_desc
130130
):
131131
activate_data_module(data_config.data_workflow_url, config_path)
132-
ray.init()
132+
if not ray.is_initialized():
133+
ray.init()
133134
if config.mode == "explore":
134135
explore(config)
135136
elif config.mode == "train":

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class DataConfig:
4545

4646
dataset_path: str = ""
4747
train_split: str = "train"
48+
subset_name: Optional[str] = None
4849
eval_split: Optional[str] = None # TODO: check data format
4950
format_config: FormatConfig = field(default_factory=FormatConfig)
5051

trinity/common/models/vllm_worker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.distributed
66
from vllm.worker.worker import Worker
77

8-
from trinity.utils.distributed import init_process_group
8+
from trinity.utils.distributed import init_process_group, is_ipv6_address
99
from trinity.utils.log import get_logger
1010

1111
logger = get_logger(__name__)
@@ -43,9 +43,15 @@ def init_process_group(
4343
)
4444
self._weight_update_rank = torch.distributed.get_rank() + rank_offset
4545

46+
if is_ipv6_address(master_address):
47+
# using tcp://ipv6:port will lead to ValueError
48+
init_method = f"tcp://[{master_address}]:{master_port}"
49+
else:
50+
init_method = f"tcp://{master_address}:{master_port}"
51+
4652
self._model_update_group = init_process_group(
4753
backend=backend,
48-
init_method=f"tcp://{master_address}:{master_port}",
54+
init_method=init_method,
4955
world_size=world_size,
5056
rank=self._weight_update_rank,
5157
group_name=group_name,

trinity/common/task.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def task_generator(
101101
yield task
102102

103103

104+
def load_hf_dataset(config: DataConfig, split: str):
    """Load a Hugging Face dataset described by *config* for the given *split*.

    When ``config.subset_name`` is set it is forwarded to ``load_dataset`` as
    the dataset configuration name (e.g. ``"main"`` for ``openai/gsm8k``);
    otherwise only the dataset path and split are passed.
    """
    positional = [config.dataset_path]
    if config.subset_name is not None:
        # subset_name is the optional HF "config name" positional argument.
        positional.append(config.subset_name)
    return load_dataset(*positional, split=split)
111+
112+
104113
@dataclass
105114
class TaskSet:
106115
"""A TaskSet class that defines a set of tasks and their associated reward functions."""
@@ -125,7 +134,8 @@ def load(
125134
# disable datasets caching to avoid reuse old-version dataset
126135
datasets.disable_caching()
127136
if task_type == TaskType.EVAL:
128-
dataset = load_dataset(config.dataset_path)[config.eval_split]
137+
assert config.eval_split is not None, "eval_split must be provided for eval taskset."
138+
dataset = load_hf_dataset(config, config.eval_split)
129139
else: # default
130140
if task_type != TaskType.EVAL and config.db_url != "":
131141
logger.info(f"Loading dataset from database with url: {config.db_url}")
@@ -134,7 +144,7 @@ def load(
134144
dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
135145
elif config.dataset_path != "":
136146
logger.info(f"Loading dataset from local file with path: {config.dataset_path}.")
137-
dataset = load_dataset(config.dataset_path)[config.train_split]
147+
dataset = load_hf_dataset(config, config.train_split)
138148
else:
139149
raise ValueError("No dataset path or db url provided.")
140150
datasets.enable_caching()

trinity/trainer/verl/fsdp_workers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
5252

5353
from trinity.common.constants import AlgorithmType
54-
from trinity.utils.distributed import init_process_group
54+
from trinity.utils.distributed import init_process_group, is_ipv6_address
5555

5656
logger = logging.getLogger(__file__)
5757
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
@@ -592,9 +592,15 @@ def setup_weight_sync_group(self):
592592
setup_ref = explorer.setup_weight_sync_group.remote(
593593
master_address, master_port, self.state_dict_meta
594594
)
595+
if is_ipv6_address(master_address):
596+
# using tcp://ipv6:port will lead to ValueError
597+
init_method = f"tcp://[{master_address}]:{master_port}"
598+
else:
599+
init_method = f"tcp://{master_address}:{master_port}"
600+
595601
self._model_update_group = init_process_group(
596602
backend=backend,
597-
init_method=f"tcp://{master_address}:{master_port}",
603+
init_method=init_method,
598604
world_size=world_size,
599605
rank=0,
600606
group_name=group_name,

trinity/utils/distributed.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""For distributed training with multiple process groups."""
3+
import ipaddress
34
from datetime import timedelta
45
from typing import Any, Optional, Union
56

@@ -15,6 +16,14 @@
1516
)
1617

1718

19+
def is_ipv6_address(ip_str: str) -> bool:
    """Return ``True`` when *ip_str* parses as a valid IPv6 address.

    Any string that is not a valid IP address at all, or that parses as
    IPv4, yields ``False``.
    """
    try:
        parsed = ipaddress.ip_address(ip_str)
    except ValueError:
        # Not a syntactically valid IPv4/IPv6 address.
        return False
    return parsed.version == 6
25+
26+
1827
def init_process_group(
1928
backend: Union[str, Backend] = None,
2029
init_method: Optional[str] = None,

0 commit comments

Comments
 (0)