Skip to content

Commit f24294b

Browse files
Authored by dhh1995, yulangz, and garrett4wade
feat: support proxy server and client for training openai-compatible agents (#500)
* feat: support proxy server and client for training openai-compatible agents
* change proxy client to proxy session
* add gitignore
* add note for client example
* launch agent in subprocess (#506)
* support proxy agent in subprocess mode
* move proxy agents examples into experimental
* add notes and remove experimental examples
* add warning message
* merge dynamic importing stuffs

Co-authored-by: 仲青 <[email protected]>
Co-authored-by: yulangz <[email protected]>
Co-authored-by: garrett4wade <[email protected]>
1 parent e434995 commit f24294b

File tree

9 files changed

+1484
-11
lines changed

9 files changed

+1484
-11
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
.data/
44
.idea/
55

6+
7+
# Ruff
8+
.ruff_cache/
9+
610
# Mac
711
.DS_Store
812

areal/api/cli_args.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import argparse
22
import json
33
import os
4-
from dataclasses import asdict, dataclass, field
4+
from dataclasses import MISSING as dataclass_missing
5+
from dataclasses import asdict, dataclass, field, fields
56
from pathlib import Path
67
from typing import Any
78

@@ -13,11 +14,13 @@
1314
from omegaconf import MISSING, DictConfig, OmegaConf
1415

1516
from areal.platforms import current_platform
16-
from areal.utils import name_resolve, pkg_version
17+
from areal.utils import logging, name_resolve, pkg_version
1718
from areal.utils.pkg_version import is_version_less
1819

1920
uvloop.install()
2021

22+
logger = logging.getLogger("CLI args")
23+
2124

2225
@dataclass
2326
class NormConfig:
@@ -160,6 +163,47 @@ def new(self, **kwargs):
160163
args.update(kwargs)
161164
return GenerationHyperparameters(**args)
162165

166+
def to_openai_args_dict(
    self, exclude_args: list[str] | None = None
) -> dict[str, Any]:
    """Convert the generation hyperparameters to a dictionary of arguments for OpenAI client."""
    # Arguments with no OpenAI chat-completions equivalent, plus any the
    # caller asked to drop explicitly.
    skipped = {
        "min_new_tokens",  # Not supported by OpenAI
        "greedy",  # Not directly supported by OpenAI
        "top_k",  # Not supported by OpenAI
        "stop_token_ids",  # Not supported by OpenAI
    }
    if exclude_args is not None:
        skipped.update(exclude_args)

    # AReaL name -> OpenAI name, for args that exist under a different key.
    renamed = {
        "n_samples": "n",
        "max_new_tokens": "max_completion_tokens",
    }

    field_by_name = {fld.name: fld for fld in fields(self)}
    openai_args: dict[str, Any] = {}
    for name, value in asdict(self).items():
        if name not in skipped:
            openai_args[renamed.get(name, name)] = value
            continue

        # Warn only when a dropped argument was actually changed away from
        # its declared default, so silently-defaulted fields stay quiet.
        fld = field_by_name[name]
        current = getattr(self, name)
        if fld.default is not dataclass_missing:
            changed = current != fld.default
        elif fld.default_factory is not dataclass_missing:
            changed = current != fld.default_factory()
        else:
            changed = False
        if changed:
            logger.warning(f"Unsupported arg for openai format: `{name}`")

    return openai_args
206+
163207

164208
# Train Engine Configs
165209

areal/experimental/openai/client.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
from areal.api.engine_api import InferenceEngine
5050

5151
# reset OpenAI keys when using the wrapped client.
52-
os.environ["OPENAI_API_KEY"] = "none"
53-
os.environ["OPENAI_BASE_URL"] = "none"
52+
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "none")
53+
os.environ["OPENAI_BASE_URL"] = os.environ.get("OPENAI_BASE_URL", "none")
5454

5555
logger = logging.getLogger("AReaLOpenAI Client")
5656

@@ -97,6 +97,7 @@ async def create(
9797
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
9898
top_p: float | None | NotGiven = NOT_GIVEN,
9999
extra_body: Body | None = None,
100+
areal_completion_cache: dict[str, InteractionWithTokenLogpReward] | None = None,
100101
**kwargs: Any,
101102
) -> ChatCompletion:
102103
"""Override create method to use AReaL engine and cache responses."""
@@ -218,7 +219,15 @@ async def create(
218219

219220
if is_omitted(store) or store:
220221
# Cache the completion with its input messages
221-
self._cache[completion_id] = InteractionWithTokenLogpReward(
222+
cache = (
223+
areal_completion_cache
224+
if areal_completion_cache is not None
225+
else self._cache
226+
)
227+
if completion_id in cache:
228+
raise ValueError(f"Completion {completion_id} already exists in cache")
229+
230+
cache[completion_id] = InteractionWithTokenLogpReward(
222231
completion=deepcopy(chat_completion),
223232
model_response=response, # Should not deepcopy response because of tokenizer
224233
messages=deepcopy(messages_list), # Store a copy of the input messages
@@ -262,6 +271,7 @@ async def create(
262271
temperature: float | None | NotGiven = NOT_GIVEN,
263272
top_p: float | None | NotGiven = NOT_GIVEN,
264273
extra_body: Body | None = None,
274+
areal_response_cache: dict[str, InteractionWithTokenLogpReward] | None = None,
265275
**kwargs: Any,
266276
) -> Response:
267277
"""Override create method to use AReaL engine"""
@@ -490,7 +500,13 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
490500
)
491501

492502
# Cache the response with its input data
493-
self._cache[resp_id] = InteractionWithTokenLogpReward(
503+
cache = (
504+
areal_response_cache if areal_response_cache is not None else self._cache
505+
)
506+
if resp_id in cache:
507+
raise ValueError(f"Response {resp_id} already exists in cache")
508+
509+
cache[resp_id] = InteractionWithTokenLogpReward(
494510
response=deepcopy(response),
495511
model_response=engine_resp, # Should not deepcopy because of tokenizer
496512
input_data=(

0 commit comments

Comments (0)