
Commit ca15365

XianBW and you-n-g authored
feat: use auto gen seed when using LLM cache (#441)
* initial version
* test requirements
* fix bugs
* fix bugs
* add annotation
* fix ruff error
* fix CI
* fix CI
* fix CI
* fix CI
* change random usage
* move cache_seed_gen to core/utils.py
* fix CI
* change cache_seed_gen name

---------

Co-authored-by: Young <[email protected]>
1 parent c3fa245 commit ca15365

File tree

4 files changed: +189 −3 lines changed


rdagent/core/conf.py

Lines changed: 10 additions & 0 deletions
@@ -15,6 +15,16 @@ class RDAgentSettings(BaseSettings):
     # TODO: (xiao) think it can be a separate config.
     log_trace_path: str | None = None

+    # Behavior of returning answers to the same question when caching is enabled
+    use_auto_chat_cache_seed_gen: bool = False
+    """
+    `_create_chat_completion_inner_function` provides a feature to pass in a seed that affects the cache hash key.
+    We want an auto seed generator to supply a different default seed for `_create_chat_completion_inner_function`
+    when no seed is given,
+    so the cache only hits when the same question is asked in the same round.
+    """
+    init_chat_cache_seed: int = 42
+
     # azure document intelligence configs
     azure_document_intelligence_key: str = ""
     azure_document_intelligence_endpoint: str = ""
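
In practice the two settings work together as follows (a minimal usage sketch assumed from this config and the new tests below, not an excerpt from the repository): enabling `use_auto_chat_cache_seed_gen` makes every unseeded call draw a fresh seed, while resetting the generator to `init_chat_cache_seed` replays the same question/answer trace from the cache.

    # Minimal sketch (assumed usage): enable the auto cache seed and replay a cached trace.
    from rdagent.core.conf import RD_AGENT_SETTINGS
    from rdagent.core.utils import LLM_CACHE_SEED_GEN
    from rdagent.oai.llm_conf import LLM_SETTINGS
    from rdagent.oai.llm_utils import APIBackend

    LLM_SETTINGS.use_chat_cache = True            # read answers from the local cache
    LLM_SETTINGS.dump_chat_cache = True           # write new answers into the cache
    RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True

    LLM_CACHE_SEED_GEN.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)
    first = APIBackend().build_messages_and_create_chat_completion(
        system_prompt="You are a helpful assistant.",
        user_prompt="Give me 2 random country names",
    )
    second = APIBackend().build_messages_and_create_chat_completion(
        system_prompt="You are a helpful assistant.",
        user_prompt="Give me 2 random country names",
    )
    # first != second: each call consumed a different auto-generated seed.
    # Calling set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed) again and repeating
    # the two calls returns the same two answers, now served from the cache.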

rdagent/core/utils.py

Lines changed: 43 additions & 1 deletion
@@ -5,6 +5,7 @@
 import json
 import multiprocessing as mp
 import pickle
+import random
 from collections.abc import Callable
 from pathlib import Path
 from typing import Any, ClassVar, NoReturn, cast
@@ -86,11 +87,48 @@ class of `class_path`
     return getattr(module, class_name)


+class CacheSeedGen:
+    """
+    A global seed generator that produces a sequence of seeds.
+    It supports the `use_auto_chat_cache_seed_gen` feature.
+
+    NOTE:
+    - This seed is specifically for the cache and is different from a regular seed.
+    - If the cache is removed, setting the same seed will not produce the same QA trace.
+    """
+
+    def __init__(self) -> None:
+        self.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)
+
+    def set_seed(self, seed: int) -> None:
+        random.seed(seed)
+
+    def get_next_seed(self) -> int:
+        """Generate the next random int."""
+        return random.randint(0, 10000)  # noqa: S311
+
+
+LLM_CACHE_SEED_GEN = CacheSeedGen()
+
+
+def _subprocess_wrapper(f: Callable, seed: int, args: list) -> Any:
+    """
+    A function wrapper that gives the subprocess a fixed starting seed.
+    """
+    LLM_CACHE_SEED_GEN.set_seed(seed)
+    return f(*args)
+
+
 def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) -> list:
     """It will use multiprocessing to call the functions in func_calls with the given parameters.
     The results equals to `return [f(*args) for f, args in func_calls]`
     It will not call multiprocessing if `n=1`

+    NOTE:
+    This cooperates with the chat_cache_seed feature:
+    we get the same seed trace even when the work is spread over multiple processes.
+
     Parameters
     ----------
     func_calls : List[Tuple[Callable, Tuple]]
@@ -105,8 +143,12 @@ def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) -> list:
     """
     if n == 1:
         return [f(*args) for f, args in func_calls]
+
     with mp.Pool(processes=max(1, min(n, len(func_calls)))) as pool:
-        results = [pool.apply_async(f, args) for f, args in func_calls]
+        results = [
+            pool.apply_async(_subprocess_wrapper, args=(f, LLM_CACHE_SEED_GEN.get_next_seed(), args))
+            for f, args in func_calls
+        ]
     return [result.get() for result in results]

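
A compressed, standalone sketch of the seed handoff (written for this note as an illustration, not RD-Agent source): the parent pre-draws one seed per task, and the wrapper re-seeds inside each child, so the batch is reproducible regardless of worker scheduling.

    import multiprocessing as mp
    import random


    def _subprocess_wrapper(f, seed, args):
        random.seed(seed)          # child starts from the seed drawn by the parent
        return f(*args)


    def fake_llm_call(question):
        # stand-in for a cached LLM call: the "answer" depends only on the current seed
        return (question, random.randint(0, 10000))


    if __name__ == "__main__":
        tasks = [(fake_llm_call, (f"q{i}",)) for i in range(4)]
        random.seed(10)            # like LLM_CACHE_SEED_GEN.set_seed(10)
        with mp.Pool(4) as pool:
            # seeds are drawn in the parent, in order, before any worker runs
            async_results = [
                pool.apply_async(_subprocess_wrapper, (f, random.randint(0, 10000), args))
                for f, args in tasks
            ]
            print([r.get() for r in async_results])
        # re-running this block after random.seed(10) prints the same four pairs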

rdagent/oai/llm_utils.py

Lines changed: 7 additions & 2 deletions
@@ -3,6 +3,7 @@
 import hashlib
 import json
 import os
+import random
 import re
 import sqlite3
 import ssl
@@ -16,7 +17,8 @@
 import numpy as np
 import tiktoken

-from rdagent.core.utils import SingletonBaseClass
+from rdagent.core.conf import RD_AGENT_SETTINGS
+from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
 from rdagent.log import LogColors
 from rdagent.log import rdagent_logger as logger
 from rdagent.oai.llm_conf import LLM_SETTINGS
@@ -594,7 +596,10 @@ def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
         To make retries useful, we need to enable a seed.
         This seed is different from `self.chat_seed` for GPT. It is for the local cache mechanism enabled by RD-Agent locally.
         """
-        # TODO: we can add this function back to avoid so much `LLM_SETTINGS.log_llm_chat_content`
+        if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
+            seed = LLM_CACHE_SEED_GEN.get_next_seed()
+
+        # TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content`
         if LLM_SETTINGS.log_llm_chat_content:
             logger.info(self._build_log_messages(messages), tag="llm_messages")
         # TODO: fail to use loguru adaptor due to stream response
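
The hashing itself is outside this diff; purely as an illustration of why the seed changes cache behavior (assumed logic, not RD-Agent's actual implementation), a chat cache key can fold the seed into the digest so the same prompt with different seeds lands in different cache entries.

    # Illustration only (assumed, not RD-Agent's actual hashing): the seed is part of
    # the cache key, so an unseeded repeat of the same prompt can still get a fresh
    # answer, while a repeated (prompt, seed) pair hits the same cache entry.
    import hashlib
    import json


    def chat_cache_key(messages: list[dict], seed: int | None) -> str:
        payload = json.dumps({"messages": messages, "seed": seed}, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()


    msgs = [{"role": "user", "content": "Give me 2 random country names"}]
    print(chat_cache_key(msgs, seed=101))
    print(chat_cache_key(msgs, seed=2024))  # different seed -> different key -> cache miss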

test/oai/test_completion.py

Lines changed: 129 additions & 0 deletions
@@ -5,6 +5,14 @@
 from rdagent.oai.llm_utils import APIBackend


+def _worker(system_prompt, user_prompt):
+    api = APIBackend()
+    return api.build_messages_and_create_chat_completion(
+        system_prompt=system_prompt,
+        user_prompt=user_prompt,
+    )
+
+
 class TestChatCompletion(unittest.TestCase):
     def test_chat_completion(self) -> None:
         system_prompt = "You are a helpful assistant."
@@ -45,6 +53,127 @@ def test_chat_multi_round(self) -> None:
         response2 = session.build_chat_completion(user_prompt=user_prompt_2)
         assert response2 is not None

+    def test_chat_cache(self) -> None:
+        """
+        Tests:
+        - Single process, asking the same question with the cache enabled
+        - 2 passes
+        - The cache is not missed & the same question gets different answers.
+        """
+        from rdagent.core.conf import RD_AGENT_SETTINGS
+        from rdagent.core.utils import LLM_CACHE_SEED_GEN
+        from rdagent.oai.llm_conf import LLM_SETTINGS
+
+        system_prompt = "You are a helpful assistant."
+        user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"
+
+        origin_value = (
+            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_chat_cache,
+            LLM_SETTINGS.dump_chat_cache,
+        )
+
+        LLM_SETTINGS.use_chat_cache = True
+        LLM_SETTINGS.dump_chat_cache = True
+
+        RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True
+
+        LLM_CACHE_SEED_GEN.set_seed(10)
+        response1 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+        response2 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+
+        LLM_CACHE_SEED_GEN.set_seed(20)
+        response3 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+        response4 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+
+        LLM_CACHE_SEED_GEN.set_seed(10)
+        response5 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+        response6 = APIBackend().build_messages_and_create_chat_completion(
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+        )
+
+        # Reset, for other tests
+        (
+            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_chat_cache,
+            LLM_SETTINGS.dump_chat_cache,
+        ) = origin_value
+
+        assert (
+            response1 != response3 and response2 != response4
+        ), "Responses sequence should be determined by 'init_chat_cache_seed'"
+        assert (
+            response1 == response5 and response2 == response6
+        ), "Responses sequence should be determined by 'init_chat_cache_seed'"
+        assert (
+            response1 != response2 and response3 != response4 and response5 != response6
+        ), "Same question should get different response when use_auto_chat_cache_seed_gen=True"
+
+    def test_chat_cache_multiprocess(self) -> None:
+        """
+        Tests:
+        - Multiple processes, asking the same question with the cache enabled
+        - 2 passes
+        - The cache is not missed & the same question gets different answers.
+        """
+        from rdagent.core.conf import RD_AGENT_SETTINGS
+        from rdagent.core.utils import LLM_CACHE_SEED_GEN, multiprocessing_wrapper
+        from rdagent.oai.llm_conf import LLM_SETTINGS
+
+        system_prompt = "You are a helpful assistant."
+        user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"
+
+        origin_value = (
+            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_chat_cache,
+            LLM_SETTINGS.dump_chat_cache,
+        )
+
+        LLM_SETTINGS.use_chat_cache = True
+        LLM_SETTINGS.dump_chat_cache = True
+
+        RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True
+
+        func_calls = [(_worker, (system_prompt, user_prompt)) for _ in range(4)]
+
+        LLM_CACHE_SEED_GEN.set_seed(10)
+        responses1 = multiprocessing_wrapper(func_calls, n=4)
+        LLM_CACHE_SEED_GEN.set_seed(20)
+        responses2 = multiprocessing_wrapper(func_calls, n=4)
+        LLM_CACHE_SEED_GEN.set_seed(10)
+        responses3 = multiprocessing_wrapper(func_calls, n=4)
+
+        # Reset, for other tests
+        (
+            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_chat_cache,
+            LLM_SETTINGS.dump_chat_cache,
+        ) = origin_value
+
+        for i in range(len(func_calls)):
+            assert (
+                responses1[i] != responses2[i] and responses1[i] == responses3[i]
+            ), "Responses sequence should be determined by 'init_chat_cache_seed'"
+            for j in range(i + 1, len(func_calls)):
+                assert (
+                    responses1[i] != responses1[j] and responses2[i] != responses2[j]
+                ), "Same question should get different response when use_auto_chat_cache_seed_gen=True"
+

 if __name__ == "__main__":
     unittest.main()
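
To exercise the new cache-seed behavior locally, the suite can be run from the repository root, for example with pytest (the exact invocation is an assumption; a configured LLM backend is needed on the first, uncached pass):

    python -m pytest test/oai/test_completion.py -k "chat_cache" -v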
