Skip to content

Commit e02f4a5

Browse files
峯回楚财
authored and committed
PullRequest: 941 新增GSM8K奖励函数并修改训练配置和流程
Merge branch test/tmp1102 of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/941 Reviewed-by: 楚财 <[email protected]> * . * . * . * . * . * .
1 parent 331bc2f commit e02f4a5

File tree

11 files changed

+314
-261
lines changed

11 files changed

+314
-261
lines changed

areal/controller/rollout_controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def _commit_one_to_runner(self):
335335

336336
# Choose worker via round-robin
337337
worker = self._choose_worker()
338-
338+
self.logger.info(f"Submit rollout to worker {worker.id}, task_input: {task_input}")
339339
self.scheduler.call_engine(
340340
worker.id,
341341
"submit",

areal/examples/configs/my001/on_policy.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ gconfig:
4141
# Due to the limitations of sglang, max_new_tokens + max_prompt_len must be less than the model's context_len (set in the model's config.json),
4242
# and cannot be equal to it. See https://github.com/sgl-project/sglang/blob/f98366604b23e331422bf3c62d4e7410ae4fab87/python/sglang/srt/managers/tokenizer_manager.py#L638C9-L638C11
4343
max_new_tokens: 15360
44+
max_tokens: 16383
4445
greedy: false
4546
temperature: 1.0
4647
top_k: 1000000

areal/examples/grpo_trainer.py

Lines changed: 273 additions & 233 deletions
Large diffs are not rendered by default.

areal/extension/asystem/ascheduler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, config: dict[str, Any]):
9797
f"AsystemScheduler initialized for {self.run_name}. API URL: {self.api_url}"
9898
)
9999

100-
def batch_cleanup_jobs(self, signum):
100+
def batch_cleanup_jobs(self, signum, frame):
101101
logger.info(f"signum {signum} received: handle_signals starts")
102102
for role, job_uid in self.submitted_jobs.items():
103103
try:

areal/extension/asystem/ascheduler/rpc_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def async_create_engine(self, worker_id, engine, *args, **kwargs):
115115
last_exception = EngineCreationError(
116116
worker_id, f"Connection error: {str(e)}"
117117
)
118-
logger.error(f"Connection error on attempt {attempt + 1}: {e}")
118+
logger.warning(f"Connection error on attempt {attempt + 1}: {e}")
119119

120120
except httpx.TimeoutException as e:
121121
# Timeout errors are retryable
@@ -140,7 +140,7 @@ async def async_create_engine(self, worker_id, engine, *args, **kwargs):
140140
if last_exception is not None:
141141
if attempt < max_retries - 1:
142142
logger.warning(
143-
f"Retrying create_engine in 1 second... ({attempt + 1}/{max_retries})"
143+
f"Retrying create_engine in 5 second... ({attempt + 1}/{max_retries})"
144144
)
145145
await asyncio.sleep(5)
146146
continue

areal/extension/asystem/ascheduler/scripts/launch-worker.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ if [[ -n "${PORT_LIST}" ]]; then
4747
# 获取第一个端口
4848
FIRST_PORT="${PORTS[0]}"
4949
# 添加到 WORKER_COMMAND
50-
WORKER_COMMAND="/usr/bin/python -u -m areal.scheduler.rpc.async_rpc_server --worker-type ${WORKER_TYPE} --worker-index ${WORKER_INDEX} --port ${FIRST_PORT}"
50+
WORKER_COMMAND="/usr/bin/python -u -m areal.scheduler.rpc.rpc_server --worker-type ${WORKER_TYPE} --worker-index ${WORKER_INDEX} --port ${FIRST_PORT}"
5151
else
52-
WORKER_COMMAND="/usr/bin/python -u -m areal.scheduler.rpc.async_rpc_server --worker-type ${WORKER_TYPE} --worker-index ${WORKER_INDEX}"
52+
WORKER_COMMAND="/usr/bin/python -u -m areal.scheduler.rpc.rpc_server --worker-type ${WORKER_TYPE} --worker-index ${WORKER_INDEX}"
5353
fi
5454

5555
#log output to local worker dir

areal/extension/asystem/math_reward.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ async def reward_fn(
3838

3939
format_rewards = []
4040

41-
query_id = kwargs.get("query_id")[0]
42-
task = kwargs.get("task")[0]
41+
query_id = kwargs.get("query_id")
42+
task = kwargs.get("task")
4343
answers = [completion]
4444
query_id_strs = [query_id]
4545

@@ -59,7 +59,7 @@ async def reward_fn(
5959
elif task == "ifeval":
6060
format_rewards = await ifeval_verify(id2info, answers, query_id_strs)
6161
elif task == "swe":
62-
extra_info = kwargs.get("extra_info")[0]
62+
extra_info = kwargs.get("extra_info")
6363
if extra_info and extra_info.get("provider", "functioncall") == "local":
6464
format_rewards = await local_swe_verify(id2info, answers, query_id_strs)
6565
else:
@@ -224,8 +224,8 @@ def extract_python_code(text, min_length=20, strict_syntax=False):
224224
async def main():
225225
answer = "<answer>\n28\n</answer>"
226226
data = {
227-
"task": ["general"],
228-
"query_id": ["general-42941"],
227+
"task": "general",
228+
"query_id": "general-42941",
229229
"prompt": [
230230
"<role>HUMAN</role>33岁孩子不听话,如何处理父子之间矛盾?<role>ASSISTANT</role>"
231231
],

areal/extension/asystem/remote_hybrid_inference_worker.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
77
from dataclasses import dataclass
88
from queue import Empty, Full, Queue
9-
from typing import TYPE_CHECKING, Any, Optional
9+
from typing import Any, Optional
1010

1111
import aiohttp
1212
import requests
@@ -22,15 +22,14 @@
2222
RolloutStat,
2323
WeightUpdateMeta,
2424
)
25+
from areal.api.workflow_api import RolloutWorkflow
2526
from areal.extension.asystem.api.cli_args import RemoteHybridInferenceConfig
2627
from areal.extension.asystem.util import wait_future_ordered
2728
from areal.utils import logging, seeding
2829
from areal.utils.data import concat_padded_tensors, cycle_dataloader
2930
from areal.utils.errors import EngineError, FrameworkError
3031
from areal.utils.http import arequest_with_retry, get_default_connector
3132

32-
if TYPE_CHECKING:
33-
from areal.api.workflow_api import RolloutWorkflow
3433
logger = logging.getLogger(__name__)
3534

3635
ROLLOUT_POLL_WAIT_TIME = 0.05
@@ -236,9 +235,7 @@ async def _rollout_thread_async(self):
236235
# logger.info(f"Get data from puller: {data}")
237236
task = asyncio.create_task(
238237
(
239-
workflow.arun_episodes(self, data)
240-
if isinstance(data, list)
241-
else workflow.arun_episode(self, data)
238+
workflow.arun_episode(self, data)
242239
),
243240
name=str(rid),
244241
)
@@ -345,6 +342,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
345342
start_time = time.perf_counter()
346343
accumulated_output_tokens = []
347344
accumulated_output_logprobs = []
345+
accumulated_versions = []
348346

349347
# Deal with rollout interruption
350348
stop_reason = ""
@@ -385,6 +383,9 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
385383
# Update accumulated outputs
386384
accumulated_output_tokens.extend(output_tokens)
387385
accumulated_output_logprobs.extend(output_logprobs)
386+
accumulated_versions.extend(
387+
[self.get_version()] * len(output_logprobs)
388+
)
388389

389390
# Check if generation is complete
390391
finish_reason = meta_info["finish_reason"]
@@ -399,7 +400,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
399400
input_tokens=req.input_ids,
400401
output_tokens=accumulated_output_tokens,
401402
output_logprobs=accumulated_output_logprobs,
402-
output_version=self.get_version(),
403+
output_versions=accumulated_versions,
403404
stop_reason=stop_reason,
404405
latency=latency,
405406
ttft=latency, # Simplified for non-streaming
@@ -532,14 +533,13 @@ def update_weights_from_disk(self, addr, path: str):
532533

533534
def submit(
534535
self,
535-
data: list[dict[str, Any]] | dict[str, Any],
536-
workflow: "RolloutWorkflow",
536+
data: dict[str, Any],
537+
workflow: RolloutWorkflow | None = None,
538+
workflow_builder: Callable | None = None,
539+
should_accept: Callable | None = None,
537540
) -> None:
538541
try:
539-
if not isinstance(data, list):
540-
data = [data]
541-
for d in data:
542-
self.input_queue.put_nowait((d, workflow))
542+
self.input_queue.put_nowait((data, workflow))
543543
except Full:
544544
raise FrameworkError(
545545
"FrameworkError",
@@ -548,7 +548,7 @@ def submit(
548548
)
549549

550550
def submit_batch(
551-
self, data: list[dict[str, Any]], workflow: "RolloutWorkflow"
551+
self, data: list[dict[str, Any]], workflow: RolloutWorkflow
552552
) -> None:
553553
try:
554554
self.input_queue.put_nowait(data, workflow)
@@ -701,3 +701,11 @@ def notify_event(self, event: str, global_step: int) -> None:
701701
except Exception as e:
702702
raise EngineError("InferenceEngineError", "NotifyEventError", e)
703703
return None
704+
705+
def wait_quiet(
706+
self, count: int, timeout: float | None = None, max_retries: int = 1,
707+
) -> dict[str, Any] | None:
708+
try:
709+
return self.wait(count, timeout=timeout)
710+
except TimeoutError:
711+
return "NO_RESULT"

areal/extension/asystem/util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def __init__(self, data_source, seed=42):
5353
self.shuffle_indices = get_shuffle_indices(size=len(data_source), seed=seed)
5454

5555
def __iter__(self):
56-
for idx in self.shuffle_indices:
57-
yield from idx
56+
return iter(self.shuffle_indices)
5857

5958
def __len__(self):
6059
return len(self.data_source)

areal/reward/gsm8k.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    """Compute the GSM8K reward for a completion as an integer.

    Only ``completions`` and ``answer`` are used; ``prompt``, ``prompt_ids``,
    ``completion_ids`` and any extra keyword arguments are accepted but
    ignored.

    Returns:
        ``int(...)`` of the first element of
        ``process_results(completions, answer)``.
        NOTE(review): presumably that element is a correctness flag, giving
        a 0/1 reward -- confirm against ``areal.reward.math_parser``.
    """
    # Imported inside the function, as in the original, so the parser module
    # is only loaded when a reward is actually computed.
    from areal.reward.math_parser import process_results

    parsed = process_results(completions, answer)
    first_entry = parsed[0]
    return int(first_entry)

0 commit comments

Comments
 (0)