Commit 4c1ef09
fix: use local tokenizer as an option
1 parent 9ead653 commit 4c1ef09

File tree: 2 files changed, +9 −3 lines


graphgen/generate.py

Lines changed: 4 additions & 1 deletion
@@ -72,15 +72,18 @@ def main():
         ),
     )
 
+    tokenizer_instance = Tokenizer(model_name=config["tokenizer"])
     synthesizer_llm_client = OpenAIModel(
         model_name=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+        tokenizer_instance=tokenizer_instance,
     )
     trainee_llm_client = OpenAIModel(
         model_name=os.getenv("TRAINEE_MODEL"),
         api_key=os.getenv("TRAINEE_API_KEY"),
         base_url=os.getenv("TRAINEE_BASE_URL"),
+        tokenizer_instance=tokenizer_instance,
     )
 
     graph_gen = GraphGen(
@@ -89,7 +92,7 @@ def main():
         synthesizer_llm_client=synthesizer_llm_client,
         trainee_llm_client=trainee_llm_client,
         search_config=config["search"],
-        tokenizer_instance=Tokenizer(model_name=config["tokenizer"]),
+        tokenizer_instance=tokenizer_instance,
     )
 
     graph_gen.insert(data, config["input_data_type"])
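
For context, the pattern this file now uses is plain dependency injection: one Tokenizer is built from the config and handed to both OpenAIModel clients, so they share a single instance instead of each constructing their own. A minimal, self-contained sketch of that pattern follows (the class bodies here are hypothetical stand-ins for illustration; the real Tokenizer and OpenAIModel live in the graphgen package):

from dataclasses import dataclass, field

@dataclass
class Tokenizer:
    # Stand-in for graphgen's tokenizer; model_name mirrors the diff's usage.
    model_name: str = "cl100k_base"

@dataclass
class OpenAIModel:
    model_name: str = ""
    # default_factory keeps the field optional, as in the real class;
    # callers that care (like generate.py above) pass a shared instance.
    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)

shared = Tokenizer(model_name="cl100k_base")
synthesizer = OpenAIModel(model_name="synthesizer", tokenizer_instance=shared)
trainee = OpenAIModel(model_name="trainee", tokenizer_instance=shared)
assert synthesizer.tokenizer_instance is trainee.tokenizer_instance  # one tokenizer, both clients

Sharing one instance means the tokenizer's setup cost is paid once, and GraphGen counts tokens with exactly the same encoder as both clients.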

graphgen/models/llm/openai_model.py

Lines changed: 5 additions & 2 deletions
@@ -55,6 +55,8 @@ class OpenAIModel(TopkTokenModel):
     rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
     tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
 
+    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
+
     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
         self.client = AsyncOpenAI(
@@ -125,8 +127,9 @@ async def generate_answer(
 
         prompt_tokens = 0
         for message in kwargs["messages"]:
-            # TODO: need to use local tokenizer to avoid network call
-            prompt_tokens += len(Tokenizer().encode_string(message["content"]))
+            prompt_tokens += len(
+                self.tokenizer_instance.encode_string(message["content"])
+            )
         estimated_tokens = prompt_tokens + kwargs["max_tokens"]
 
         if self.request_limit:
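
The point of the change in generate_answer is that token counting now runs on a tokenizer the caller provides, rather than a fresh Tokenizer() built for every message. A sketch of what such a local tokenizer wrapper might look like, assuming it is backed by tiktoken (only the Tokenizer name and the encode_string method appear in the diff; the tiktoken backing and the encoding name are assumptions):

import tiktoken

class Tokenizer:
    def __init__(self, model_name: str = "cl100k_base"):
        # Loads the BPE table once; every encode after this is a local,
        # in-process call with no network round-trip.
        self.enc = tiktoken.get_encoding(model_name)

    def encode_string(self, text: str) -> list[int]:
        return self.enc.encode(text)

# Estimating prompt tokens the way generate_answer does:
tok = Tokenizer()
messages = [{"content": "hello"}, {"content": "world"}]
prompt_tokens = sum(len(tok.encode_string(m["content"])) for m in messages)
estimated_tokens = prompt_tokens + 1024  # plus max_tokens, as in the diff

Because the same instance is reused across requests, the encoder is constructed once per process instead of once per message, which is what makes the rate-limit estimate cheap enough to run on every call.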
