|
49 | 49 | from areal.api.engine_api import InferenceEngine |
50 | 50 |
|
51 | 51 | # Fall back to "none" for the OpenAI env vars only when they are unset (existing values are preserved). |
52 | | -os.environ["OPENAI_API_KEY"] = "none" |
53 | | -os.environ["OPENAI_BASE_URL"] = "none" |
| 52 | +os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "none") |
| 53 | +os.environ["OPENAI_BASE_URL"] = os.environ.get("OPENAI_BASE_URL", "none") |
54 | 54 |
|
55 | 55 | logger = logging.getLogger("AReaLOpenAI Client") |
56 | 56 |
|
@@ -97,6 +97,7 @@ async def create( |
97 | 97 | tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, |
98 | 98 | top_p: float | None | NotGiven = NOT_GIVEN, |
99 | 99 | extra_body: Body | None = None, |
| 100 | + areal_completion_cache: dict[str, InteractionWithTokenLogpReward] | None = None, |
100 | 101 | **kwargs: Any, |
101 | 102 | ) -> ChatCompletion: |
102 | 103 | """Override create method to use AReaL engine and cache responses.""" |
@@ -218,7 +219,15 @@ async def create( |
218 | 219 |
|
219 | 220 | if is_omitted(store) or store: |
220 | 221 | # Cache the completion with its input messages |
221 | | - self._cache[completion_id] = InteractionWithTokenLogpReward( |
| 222 | + cache = ( |
| 223 | + areal_completion_cache |
| 224 | + if areal_completion_cache is not None |
| 225 | + else self._cache |
| 226 | + ) |
| 227 | + if completion_id in cache: |
| 228 | + raise ValueError(f"Completion {completion_id} already exists in cache") |
| 229 | + |
| 230 | + cache[completion_id] = InteractionWithTokenLogpReward( |
222 | 231 | completion=deepcopy(chat_completion), |
223 | 232 | model_response=response, # Should not deepcopy response because of tokenizer |
224 | 233 | messages=deepcopy(messages_list), # Store a copy of the input messages |
@@ -262,6 +271,7 @@ async def create( |
262 | 271 | temperature: float | None | NotGiven = NOT_GIVEN, |
263 | 272 | top_p: float | None | NotGiven = NOT_GIVEN, |
264 | 273 | extra_body: Body | None = None, |
| 274 | + areal_response_cache: dict[str, InteractionWithTokenLogpReward] | None = None, |
265 | 275 | **kwargs: Any, |
266 | 276 | ) -> Response: |
267 | 277 | """Override create method to use AReaL engine""" |
@@ -490,7 +500,13 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]: |
490 | 500 | ) |
491 | 501 |
|
492 | 502 | # Cache the response with its input data |
493 | | - self._cache[resp_id] = InteractionWithTokenLogpReward( |
| 503 | + cache = ( |
| 504 | + areal_response_cache if areal_response_cache is not None else self._cache |
| 505 | + ) |
| 506 | + if resp_id in cache: |
| 507 | + raise ValueError(f"Response {resp_id} already exists in cache") |
| 508 | + |
| 509 | + cache[resp_id] = InteractionWithTokenLogpReward( |
494 | 510 | response=deepcopy(response), |
495 | 511 | model_response=engine_resp, # Should not deepcopy because of tokenizer |
496 | 512 | input_data=( |
|
0 commit comments