
Commit 22f4693

feat: add rpm & tpm limit
1 parent 3c41685 commit 22f4693

File tree: 4 files changed (+146 −5 lines)

graphgen/models/llm/limitter.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import time
+from datetime import datetime, timedelta
+import asyncio
+
+from graphgen.utils import logger
+
+
+class RPM:
+
+    def __init__(self, rpm: int = 1000):
+        self.rpm = rpm
+        self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+
+    def get_minute_slot(self):
+        current_time = time.time()
+        dt_object = datetime.fromtimestamp(current_time)
+        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
+        return total_minutes_since_midnight
+
+    async def wait(self, silent=False):
+        current = time.time()
+        dt_object = datetime.fromtimestamp(current)
+        minute_slot = self.get_minute_slot()
+
+        if self.record['rpm_slot'] == minute_slot:
+            # check whether the RPM budget for this minute is exhausted
+            if self.record['counter'] >= self.rpm:
+                # wait until the next minute begins
+                next_minute = dt_object.replace(
+                    second=0, microsecond=0) + timedelta(minutes=1)
+                _next = next_minute.timestamp()
+                sleep_time = abs(_next - current)
+                if not silent:
+                    logger.info('RPM sleep %s', sleep_time)
+                await asyncio.sleep(sleep_time)
+
+                self.record = {
+                    'rpm_slot': self.get_minute_slot(),
+                    'counter': 0
+                }
+        else:
+            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+        self.record['counter'] += 1
+
+        if not silent:
+            logger.debug(self.record)
+
+
+class TPM:
+
+    def __init__(self, tpm: int = 20000):
+        self.tpm = tpm
+        self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}
+
+    def get_minute_slot(self):
+        current_time = time.time()
+        dt_object = datetime.fromtimestamp(current_time)
+        total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
+        return total_minutes_since_midnight
+
+    async def wait(self, token_count, silent=False):
+        current = time.time()
+        dt_object = datetime.fromtimestamp(current)
+        minute_slot = self.get_minute_slot()
+
+        # a new minute slot has started: reset the counter and skip waiting
+        if self.record['tpm_slot'] != minute_slot:
+            self.record = {'tpm_slot': minute_slot, 'counter': token_count}
+            return
+
+        # check whether the TPM budget for this minute is exhausted
+        self.record['counter'] += token_count
+        if self.record['counter'] > self.tpm:
+            # wait until the next minute begins
+            next_minute = dt_object.replace(
+                second=0, microsecond=0) + timedelta(minutes=1)
+            _next = next_minute.timestamp()
+            sleep_time = abs(_next - current)
+            logger.info('TPM sleep %s', sleep_time)
+            await asyncio.sleep(sleep_time)
+
+            self.record = {
+                'tpm_slot': self.get_minute_slot(),
+                'counter': token_count
+            }
+
+        if not silent:
+            logger.debug(self.record)
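
Both classes bucket traffic by minute-of-day (`hour * 60 + minute`), so a budget resets at every wall-clock minute boundary rather than over a rolling 60-second window, and neither takes a lock, so heavily concurrent callers should treat the limits as approximate. A minimal sketch of driving the limiters directly — the loop and the tiny limits are illustrative, not part of this commit:

import asyncio

from graphgen.models.llm.limitter import RPM, TPM

async def main():
    rpm = RPM(rpm=2)      # illustrative: 2 requests per wall-clock minute
    tpm = TPM(tpm=100)    # illustrative: 100 tokens per wall-clock minute

    for i in range(5):
        await rpm.wait(silent=True)  # sleeps into the next minute once the budget is spent
        await tpm.wait(40)           # charges 40 tokens to the current minute slot
        print(f'request {i} dispatched')

asyncio.run(main())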

graphgen/models/llm/openai_model.py

Lines changed: 18 additions & 1 deletion
@@ -11,7 +11,8 @@
 )
 
 from graphgen.models.llm.topk_token_model import TopkTokenModel, Token
-
+from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.llm.limitter import RPM, TPM
 
 def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
     token_logprobs = response.choices[0].logprobs.content
@@ -31,10 +32,16 @@ class OpenAIModel(TopkTokenModel):
     model_name: str = "gpt-4o-mini"
     api_key: str = None
     base_url: str = None
+
     system_prompt: str = ""
     json_mode: bool = False
     seed: int = None
+
     token_usage: list = field(default_factory=list)
+    request_limit: bool = False
+    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
+    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
+
 
     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
@@ -63,6 +70,7 @@ def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs['messages']= messages
         return kwargs
 
+
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -95,6 +103,15 @@ async def generate_answer(self, text: str, history: Optional[List[str]] = None,
         kwargs = self._pre_generate(text, history)
         kwargs["temperature"] = temperature
 
+        prompt_tokens = 0
+        for message in kwargs['messages']:
+            prompt_tokens += len(Tokenizer().encode_string(message['content']))
+        estimated_tokens = prompt_tokens + kwargs['max_tokens']
+
+        if self.request_limit:
+            await self.rpm.wait(silent=True)
+            await self.tpm.wait(estimated_tokens, silent=True)
+
         completion = await self.client.chat.completions.create( # pylint: disable=E1125
             model=self.model_name,
             **kwargs
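
Taken together, the client now estimates the cost of a request as prompt tokens plus `max_tokens` and gates on both limiters before calling the API, so the TPM budget is debited pessimistically even when the completion comes back shorter than `max_tokens`. A sketch of opting in from calling code — the field names come from this diff, while the key and URL values are placeholders:

from graphgen.models import OpenAIModel
from graphgen.models.llm.limitter import RPM, TPM

client = OpenAIModel(
    model_name="gpt-4o-mini",
    api_key="sk-placeholder",                # placeholder value
    base_url="https://api.openai.com/v1",    # placeholder value
    request_limit=True,   # throttling is opt-in; the default False skips both waits
    rpm=RPM(rpm=1000),
    tpm=TPM(tpm=50000),
)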

webui/app.py

Lines changed: 37 additions & 4 deletions
@@ -17,6 +17,7 @@
 
 from graphgen.graphgen import GraphGen
 from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models.llm.limitter import RPM, TPM
 
 css = """
 .center-row {
@@ -38,12 +39,20 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     graph_gen.synthesizer_llm_client = OpenAIModel(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
-        api_key=env.get("SYNTHESIZER_API_KEY", ""))
+        api_key=env.get("SYNTHESIZER_API_KEY", ""),
+        request_limit=True,
+        rpm=RPM(env.get("RPM", 1000)),
+        tpm=TPM(env.get("TPM", 50000)),
+    )
 
     graph_gen.trainee_llm_client = OpenAIModel(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
-        api_key=env.get("TRAINEE_API_KEY", ""))
+        api_key=env.get("TRAINEE_API_KEY", ""),
+        request_limit=True,
+        rpm=RPM(env.get("RPM", 1000)),
+        tpm=TPM(env.get("TPM", 50000)),
+    )
 
     graph_gen.tokenizer_instance = Tokenizer(
         config.get("tokenizer", "cl100k_base"))
@@ -97,7 +106,9 @@ def sum_tokens(client):
         "TRAINEE_BASE_URL": arguments[12],
         "TRAINEE_MODEL": arguments[14],
         "SYNTHESIZER_API_KEY": arguments[15],
-        "TRAINEE_API_KEY": arguments[15]
+        "TRAINEE_API_KEY": arguments[15],
+        "RPM": arguments[17],
+        "TPM": arguments[18],
     }
 
     # Test API connection
@@ -362,6 +373,28 @@ def sum_tokens(client):
             with gr.Column(scale=1):
                 test_connection_btn = gr.Button("Test Connection")
 
+        with gr.Blocks():
+            with gr.Row(equal_height=True):
+                with gr.Column():
+                    rpm = gr.Slider(
+                        label="RPM",
+                        minimum=500,
+                        maximum=10000,
+                        value=1000,
+                        step=100,
+                        interactive=True,
+                        visible=True)
+                with gr.Column():
+                    tpm = gr.Slider(
+                        label="TPM",
+                        minimum=5000,
+                        maximum=100000,
+                        value=50000,
+                        step=1000,
+                        interactive=True,
+                        visible=True)
+
+
 
         with gr.Blocks():
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
@@ -442,7 +475,7 @@ def sum_tokens(client):
                 bidirectional, expand_method, max_extra_edges, max_tokens,
                 max_depth, edge_sampling, isolated_node_strategy,
                 loss_strategy, base_url, synthesizer_model, trainee_model,
-                api_key, chunk_size, token_counter
+                api_key, chunk_size, rpm, tpm, token_counter
             ],
             outputs=[output, token_counter],
         )
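
One wiring detail worth noting: the two sliders reach `init_graph_gen` through the `env` dict (`arguments[17]` and `arguments[18]`), but each client is built with its own `RPM`/`TPM` instance, so the synthesizer and trainee budgets are tracked independently rather than against one global limit. If a single shared budget were wanted, the same limiter objects could be passed to both clients — a hypothetical variant, not what this commit does:

shared_rpm = RPM(env.get("RPM", 1000))
shared_tpm = TPM(env.get("TPM", 50000))

graph_gen.synthesizer_llm_client = OpenAIModel(
    model_name=env.get("SYNTHESIZER_MODEL", ""),
    base_url=env.get("SYNTHESIZER_BASE_URL", ""),
    api_key=env.get("SYNTHESIZER_API_KEY", ""),
    request_limit=True,
    rpm=shared_rpm,
    tpm=shared_tpm,
)

graph_gen.trainee_llm_client = OpenAIModel(
    model_name=env.get("TRAINEE_MODEL", ""),
    base_url=env.get("TRAINEE_BASE_URL", ""),
    api_key=env.get("TRAINEE_API_KEY", ""),
    request_limit=True,
    rpm=shared_rpm,    # same objects, so both clients draw on one budget
    tpm=shared_tpm,
)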

webui/count_tokens.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@
 from graphgen.models import Tokenizer
 
 def count_tokens(file, tokenizer_name, data_frame):
+    if not file or not os.path.exists(file):
+        return data_frame
+
     if file.endswith(".jsonl"):
         with open(file, "r", encoding='utf-8') as f:
             data = [json.loads(line) for line in f]
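
A quick illustration of the new guard, assuming `count_tokens` is importable as `webui.count_tokens.count_tokens` and that `data_frame` is the pandas frame the webui passes in (both assumptions, the column names doubly so): a missing or empty path now returns the incoming frame unchanged instead of raising once the file is opened.

import pandas as pd

from webui.count_tokens import count_tokens

df = pd.DataFrame(columns=["file", "tokens"])        # illustrative columns
assert count_tokens(None, "cl100k_base", df) is df   # falsy path: early return
assert count_tokens("/no/such/file.jsonl", "cl100k_base", df) is df  # nonexistent path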
