Skip to content

Commit a42edb2

Browse files
峯回daihao
authored andcommitted
PullRequest: 928 update extension
Merge branch fh/devgh of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/928 Reviewed-by: 楚财 <[email protected]> * update extension
1 parent c5fb427 commit a42edb2

File tree

5 files changed

+685
-44
lines changed

5 files changed

+685
-44
lines changed

areal/api/cli_args.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,45 +1383,3 @@ def save_config(cfg, log_dir):
13831383
default_flow_style=False,
13841384
sort_keys=False,
13851385
)
1386-
1387-
########### asystem ###########
1388-
from typing import Dict
1389-
class RemoteHybridInferenceConfig(InferenceEngineConfig):
1390-
experiment_name: str = MISSING
1391-
trial_name: str = MISSING
1392-
model_path: str = field(
1393-
default=MISSING,
1394-
metadata={"help": "model path"},
1395-
)
1396-
storage_path: str = field(
1397-
default=MISSING,
1398-
metadata={"help": "storage path"},
1399-
)
1400-
random_seed: int = field(
1401-
default=0,
1402-
metadata={"help": "random seed"},
1403-
)
1404-
engine_config: Dict = field(default_factory=dict)
1405-
dp_size: int = field(
1406-
default=1,
1407-
metadata={"help": "dp size"},
1408-
)
1409-
pp_size: int = field(
1410-
default=1,
1411-
metadata={"help": "pp size"},
1412-
)
1413-
tp_size: int = field(
1414-
default=1,
1415-
metadata={"help": "tp size"},
1416-
)
1417-
seed: int = field(
1418-
default=1,
1419-
metadata={"help": "seed"},
1420-
)
1421-
batch_requests: bool = field(
1422-
default=False,
1423-
metadata={"help": "batch requests"},
1424-
)
1425-
request_timeout: float = field(
1426-
default=7200.0, metadata={"help": "Timeout for HTTP requests."}
1427-
)
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from dataclasses import MISSING, dataclass, field
2+
3+
from areal.api.cli_args import GRPOConfig as BaseGRPOConfig
4+
from areal.api.cli_args import InferenceEngineConfig, SchedulerConfig
5+
from areal.api.cli_args import TrainEngineConfig as BaseTrainEngineConfig
6+
7+
8+
class RemoteHybridInferenceConfig(InferenceEngineConfig):
9+
model_path: str = field(
10+
default=MISSING,
11+
metadata={"help": "model path"},
12+
)
13+
storage_path: str = field(
14+
default=MISSING,
15+
metadata={"help": "storage path"},
16+
)
17+
random_seed: int = field(
18+
default=0,
19+
metadata={"help": "random seed"},
20+
)
21+
engine_config: dict = field(default_factory=dict)
22+
dp_size: int = field(
23+
default=1,
24+
metadata={"help": "dp size"},
25+
)
26+
pp_size: int = field(
27+
default=1,
28+
metadata={"help": "pp size"},
29+
)
30+
tp_size: int = field(
31+
default=1,
32+
metadata={"help": "tp size"},
33+
)
34+
seed: int = field(
35+
default=1,
36+
metadata={"help": "seed"},
37+
)
38+
batch_requests: bool = field(
39+
default=False,
40+
metadata={"help": "batch requests"},
41+
)
42+
43+
44+
@dataclass
45+
class RemoteMegatronWrapPolicy:
46+
n_minibatches: int = 1
47+
kl_ctl: float = 0.0
48+
adv_norm: bool = False
49+
discount: float = 1.0
50+
gae_lambda: float = 1.0
51+
eps_clip: float = 0.2
52+
clip_ratio_low: float = 0.2
53+
clip_ratio_high: float = 0.28
54+
c_clip: float | None = None
55+
value_eps_clip: float = 0.2
56+
max_reward_clip: float = 5.0
57+
disable_value: bool = True
58+
early_stop_kl: float | None = None
59+
early_stop_imp_ratio: float | None = None
60+
adaptive_kl_ctl: bool = False
61+
adaptive_kl_target: float | None = 6
62+
adaptive_kl_horizon: float | None = 10000
63+
enable_save: bool = True
64+
value_norm: bool = True
65+
value_norm_type: str = field(metadata={"choices": ["exp", "ma"]}, default="exp")
66+
value_norm_beta: float = 0.99995
67+
value_norm_eps: float = 1e-5
68+
group_size: int = 8
69+
generation_size: int | None = None
70+
mask_no_eos_with_zero: bool = False
71+
group_adv_norm: bool = True
72+
mask_too_long: bool = False
73+
use_dense_reward: bool = False
74+
reward_delta: bool = True
75+
token_normalize_scope: str = field(
76+
default="global", metadata={"choices": ["global", "dp"]}
77+
)
78+
sample_reuse: int = 1
79+
temperature: float = 1.0 # GenerationHyperparameters
80+
reward_output_scaling: float = field(
81+
default=1.0, metadata={"help": "Reward scaling factor"}
82+
)
83+
reward_output_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
84+
recompute_logp: bool = False
85+
86+
87+
@dataclass
88+
class RemoteMegatronEngineConfig:
89+
wrap_policy: RemoteMegatronWrapPolicy | None = field(
90+
default_factory=RemoteMegatronWrapPolicy,
91+
metadata={"help": "RemoteMegatron wrap policy."},
92+
)
93+
remote_megatron_config: dict = field(default_factory=dict)
94+
loss_configs: dict = field(default_factory=dict)
95+
recover_dir: str = field(default="")
96+
97+
@staticmethod
98+
def assign_wrap_policy(policy_dict: dict) -> RemoteMegatronWrapPolicy:
99+
"""Assign values from dictionary to RemoteMegatronWrapPolicy fields.
100+
101+
Args:
102+
policy_dict: Dictionary containing wrap policy configuration
103+
104+
Returns:
105+
Configured RemoteMegatronWrapPolicy instance
106+
"""
107+
policy = RemoteMegatronWrapPolicy()
108+
for field_name, field_value in policy_dict.items():
109+
if hasattr(policy, field_name):
110+
setattr(policy, field_name, field_value)
111+
return policy
112+
113+
experiment_name: str = field(
114+
default="test-exp",
115+
metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
116+
)
117+
trial_name: str = field(
118+
default="test-trial",
119+
metadata={"help": "Name of the trial (no '-' or '/'). Required."},
120+
)
121+
group_size: int = field(
122+
default=8,
123+
metadata={"help": "Number of answers retained per prompt (best-of-n)."},
124+
)
125+
train_bs_n_seqs: int = field(
126+
default=32, metadata={"help": "Training batch size in number of sequences"}
127+
)
128+
n_mbs: int = field(
129+
default=1,
130+
metadata={
131+
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
132+
},
133+
)
134+
max_tokens_per_mb: int = field(
135+
default=16384,
136+
metadata={
137+
"help": "Maximum tokens per micro-batch. When set, n_mbs becomes the minimum number of micro-batches",
138+
},
139+
)
140+
global_step: int = field(
141+
default=0,
142+
metadata={
143+
"help": "global step for recover",
144+
},
145+
)
146+
147+
148+
class TrainEngineConfig(BaseTrainEngineConfig):
149+
hybrid_engine: RemoteMegatronEngineConfig = field(
150+
default_factory=RemoteMegatronEngineConfig
151+
)
152+
153+
154+
@dataclass
155+
class RecoverConfig:
156+
experiment_name: str = field(default="default-experiment")
157+
trial_name: str = field(default="trial0")
158+
fileroot: str = field(default="")
159+
recover_meta_info_path: str = field(default="")
160+
enable_recover: bool = field(default=False)
161+
latest_disable_save_hf: bool = field(
162+
default=True, metadata={"help": "Disable saving latest huggingFace"}
163+
)
164+
periodic_disable_save_hf: bool = field(
165+
default=False, metadata={"help": "Disable saving periodic huggingFace"}
166+
)
167+
periodic_save_interval: int | None = field(
168+
default=None, metadata={"help": "Periodic save steps"}
169+
)
170+
latest_save_interval: int | None = field(
171+
default=None, metadata={"help": "Latest save steps"}
172+
)
173+
174+
175+
@dataclass
176+
class BaseExperimentConfigExtension:
177+
enable_colocate_mode: bool = field(
178+
default=False, metadata={"help": "Enable colocate mode."}
179+
)
180+
storage_prefix: str = field(
181+
default="", metadata={"help": "Storage prefix for colocate mode."}
182+
)
183+
weight_update_type: str = field(default="nccl", metadata={"help": "nccl/disk"})
184+
185+
scheduler: SchedulerConfig = field(
186+
default_factory=SchedulerConfig, metadata={"help": "Scheduler config."}
187+
)
188+
189+
190+
@dataclass
191+
class GRPOConfig(BaseGRPOConfig, BaseExperimentConfigExtension):
192+
rollout: RemoteHybridInferenceConfig = field(
193+
default_factory=RemoteHybridInferenceConfig
194+
)
195+
actor: TrainEngineConfig = field(default_factory=TrainEngineConfig)
196+
ref: TrainEngineConfig = field(default_factory=TrainEngineConfig)
197+
recover: RecoverConfig = field(default_factory=RecoverConfig)

0 commit comments

Comments
 (0)