Commit 2b70e4a

Tokenizer (#2051)
* Add LiteLLM chat and embedding model providers.
* Fix code review findings.
* Add litellm.
* Fix formatting.
* Update dictionary.
* Update litellm.
* Fix embedding.
* Remove manual use of tiktoken and replace with Tokenizer interface. Adds support for encoding and decoding the models supported by litellm.
* Update litellm.
* Configure litellm to drop unsupported params.
* Cleanup semversioner release notes.
* Add num_tokens util to Tokenizer interface.
* Update litellm service factories.
* Cleanup litellm chat/embedding model argument assignment.
* Update chat and embedding type field for litellm use and future migration away from fnllm.
* Flatten litellm service organization.
* Update litellm.
* Update litellm factory validation.
* Flatten litellm rate limit service organization.
* Update rate limiter - disable with None/null instead of 0.
* Fix usage of get_tokenizer.
* Update litellm service registrations.
* Add jitter to exponential retry.
* Update validation.
* Update validation.
* Add litellm request logging layer.
* Update cache key.
* Update defaults.

---------

Co-authored-by: Alonso Guevara <[email protected]>
1 parent 82cd3b7 commit 2b70e4a

66 files changed (+5305 / −1907 lines)
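
For orientation before the per-file diffs, here is a minimal sketch of what a model entry targeting the new LiteLLM providers might look like once this commit lands. The field names (type, model_provider, rate_limit_strategy, retry_strategy, tokens_per_minute, requests_per_minute) come from the diffs below; the dictionary shape and concrete values are illustrative assumptions, not part of the commit. Note the new convention that rate limiting is disabled with None/null rather than 0.

# Hypothetical settings fragment, not taken from the commit itself.
# "chat" routes through the new LiteLLM provider; model_provider is required for it.
litellm_chat_model = {
    "type": "chat",                           # new ModelType.Chat added in enums.py
    "model_provider": "openai",               # required when type is "chat"/"embedding"
    "model": "gpt-4o-mini",                   # any model litellm supports (illustrative)
    "retry_strategy": "exponential_backoff",  # one of the registered retry services
    "rate_limit_strategy": "static",          # or None to disable rate limiting entirely
    "tokens_per_minute": 50_000,              # "auto" is rejected for the litellm types
    "requests_per_minute": 300,               # None (not 0) disables this limit
}
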
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add LiteLLM chat and embedding model providers."
+}

dictionary.txt

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,7 @@ typer
 spacy
 kwargs
 ollama
+litellm

 # Library Methods
 iterrows
@@ -103,6 +104,8 @@ isin
 nocache
 nbconvert
 levelno
+acompletion
+aembedding

 # HTML
 nbsp

graphrag/api/prompt_tune.py

Lines changed: 2 additions & 1 deletion
@@ -47,6 +47,7 @@
 from graphrag.prompt_tune.generator.persona import generate_persona
 from graphrag.prompt_tune.loader.input import load_docs_in_chunks
 from graphrag.prompt_tune.types import DocSelectionType
+from graphrag.tokenizer.get_tokenizer import get_tokenizer

 logger = logging.getLogger(__name__)

@@ -166,7 +167,7 @@ async def generate_indexing_prompts(
         examples=examples,
         language=language,
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
-        encoding_model=extract_graph_llm_settings.encoding_model,
+        tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
         max_token_count=max_tokens,
         min_examples_required=min_examples_required,
     )
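
The change above replaces the old encoding_model plumbing with the Tokenizer interface mentioned in the commit message. A rough sketch of the intended usage follows, assuming the interface exposes encode/decode plus the num_tokens helper the commit message describes; the exact method names and signatures are an assumption, not confirmed by this diff.

# Sketch only: tokenizer resolution via the new interface instead of manual tiktoken.
tokenizer = get_tokenizer(model_config=extract_graph_llm_settings)

text = "GraphRAG prompt tuning sample text"
token_ids = tokenizer.encode(text)      # assumed: encode text to token ids
restored = tokenizer.decode(token_ids)  # assumed: decode token ids back to text
count = tokenizer.num_tokens(text)      # num_tokens util added in this commit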

graphrag/config/defaults.py

Lines changed: 34 additions & 0 deletions
@@ -3,6 +3,7 @@

 """Common default configuration values."""

+from collections.abc import Callable
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import ClassVar, Literal
@@ -23,6 +24,25 @@
 from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
     EN_STOP_WORDS,
 )
+from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
+    RateLimiter,
+)
+from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
+    StaticRateLimiter,
+)
+from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
+    ExponentialRetry,
+)
+from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
+    IncrementalWaitRetry,
+)
+from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
+    NativeRetry,
+)
+from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
+    RandomWaitRetry,
+)
+from graphrag.language_model.providers.litellm.services.retry.retry import Retry

 DEFAULT_OUTPUT_BASE_DIR = "output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@@ -39,6 +59,18 @@
 COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"


+DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
+    "native": NativeRetry,
+    "exponential_backoff": ExponentialRetry,
+    "random_wait": RandomWaitRetry,
+    "incremental_wait": IncrementalWaitRetry,
+}
+
+DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
+    "static": StaticRateLimiter,
+}
+
+
 @dataclass
 class BasicSearchDefaults:
     """Default values for basic search."""
@@ -275,6 +307,7 @@ class LanguageModelDefaults:

     api_key: None = None
     auth_type: ClassVar[AuthType] = AuthType.APIKey
+    model_provider: str | None = None
     encoding_model: str = ""
     max_tokens: int | None = None
     temperature: float = 0
@@ -294,6 +327,7 @@ class LanguageModelDefaults:
     model_supports_json: None = None
     tokens_per_minute: Literal["auto"] = "auto"
     requests_per_minute: Literal["auto"] = "auto"
+    rate_limit_strategy: str | None = "static"
     retry_strategy: str = "native"
     max_retries: int = 10
     max_retry_wait: float = 10.0
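
The two registries added above map strategy names to service constructors. A small sketch of how a name might be resolved through them, assuming the corresponding factories ultimately perform a lookup of this shape and that the constructors accept the keyword arguments used by the validation code in graph_rag_config.py below (both are assumptions, not shown in this file):

# Illustrative lookup only; the real wiring goes through RetryFactory/RateLimiterFactory.
retry_cls = DEFAULT_RETRY_SERVICES["exponential_backoff"]  # -> ExponentialRetry
retry = retry_cls(max_attempts=10, max_retry_wait=10.0)     # assumed constructor args

limiter_cls = DEFAULT_RATE_LIMITER_SERVICES["static"]       # -> StaticRateLimiter
limiter = limiter_cls(rpm=300, tpm=50_000)                   # assumed constructor args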

graphrag/config/enums.py

Lines changed: 2 additions & 0 deletions
@@ -86,10 +86,12 @@ class ModelType(str, Enum):
     # Embeddings
     OpenAIEmbedding = "openai_embedding"
     AzureOpenAIEmbedding = "azure_openai_embedding"
+    Embedding = "embedding"

     # Chat Completion
     OpenAIChat = "openai_chat"
     AzureOpenAIChat = "azure_openai_chat"
+    Chat = "chat"

     # Debug
     MockChat = "mock_chat"

graphrag/config/models/graph_rag_config.py

Lines changed: 53 additions & 0 deletions
@@ -37,6 +37,12 @@
 from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
 from graphrag.config.models.umap_config import UmapConfig
 from graphrag.config.models.vector_store_config import VectorStoreConfig
+from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
+    RateLimiterFactory,
+)
+from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
+    RetryFactory,
+)


 class GraphRagConfig(BaseModel):
@@ -89,6 +95,47 @@ def _validate_models(self) -> None:
         if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
             raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)

+    def _validate_retry_services(self) -> None:
+        """Validate the retry services configuration."""
+        retry_factory = RetryFactory()
+
+        for model_id, model in self.models.items():
+            if model.retry_strategy != "none":
+                if model.retry_strategy not in retry_factory:
+                    msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
+                    raise ValueError(msg)
+
+                _ = retry_factory.create(
+                    strategy=model.retry_strategy,
+                    max_attempts=model.max_retries,
+                    max_retry_wait=model.max_retry_wait,
+                )
+
+    def _validate_rate_limiter_services(self) -> None:
+        """Validate the rate limiter services configuration."""
+        rate_limiter_factory = RateLimiterFactory()
+
+        for model_id, model in self.models.items():
+            if model.rate_limit_strategy is not None:
+                if model.rate_limit_strategy not in rate_limiter_factory:
+                    msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
+                    raise ValueError(msg)
+
+                rpm = (
+                    model.requests_per_minute
+                    if type(model.requests_per_minute) is int
+                    else None
+                )
+                tpm = (
+                    model.tokens_per_minute
+                    if type(model.tokens_per_minute) is int
+                    else None
+                )
+                if rpm is not None or tpm is not None:
+                    _ = rate_limiter_factory.create(
+                        strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
+                    )
+
     input: InputConfig = Field(
         description="The input configuration.", default=InputConfig()
     )
@@ -300,6 +347,11 @@ def _validate_vector_store_db_uri(self) -> None:
                 raise ValueError(msg)
             store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())

+    def _validate_factories(self) -> None:
+        """Validate the factories used in the configuration."""
+        self._validate_retry_services()
+        self._validate_rate_limiter_services()
+
     def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
         """Get a model configuration by ID.

@@ -360,4 +412,5 @@ def _validate_model(self):
         self._validate_multi_output_base_dirs()
         self._validate_update_index_output_base_dir()
         self._validate_vector_store_db_uri()
+        self._validate_factories()
         return self

graphrag/config/models/language_model_config.py

Lines changed: 67 additions & 3 deletions
@@ -73,8 +73,11 @@ def _validate_auth_type(self) -> None:
         ConflictingSettingsError
             If the Azure authentication type conflicts with the model being used.
         """
-        if self.auth_type == AuthType.AzureManagedIdentity and (
-            self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
+        if (
+            self.auth_type == AuthType.AzureManagedIdentity
+            and self.type != ModelType.AzureOpenAIChat
+            and self.type != ModelType.AzureOpenAIEmbedding
+            and self.model_provider != "azure"  # indicates Litellm + AOI
         ):
             msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
             raise ConflictingSettingsError(msg)
@@ -94,6 +97,27 @@ def _validate_type(self) -> None:
             msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
             raise KeyError(msg)

+    model_provider: str | None = Field(
+        description="The model provider to use.",
+        default=language_model_defaults.model_provider,
+    )
+
+    def _validate_model_provider(self) -> None:
+        """Validate the model provider.
+
+        Required when using Litellm.
+
+        Raises
+        ------
+        KeyError
+            If the model provider is not recognized.
+        """
+        if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
+            self.model_provider is None or self.model_provider.strip() == ""
+        ):
+            msg = f"Model provider must be specified when using type == {self.type}."
+            raise KeyError(msg)
+
     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(
         description="The encoding model to use",
@@ -103,12 +127,27 @@ def _validate_type(self) -> None:
     def _validate_encoding_model(self) -> None:
         """Validate the encoding model.

+        The default behavior is to use an encoding model that matches the LLM model.
+        LiteLLM supports 100+ models and their tokenization. There is no need to
+        set the encoding model when using the new LiteLLM provider as was done with fnllm provider.
+
+        Users can still manually specify a tiktoken based encoding model to use even with the LiteLLM provider
+        in which case the specified encoding model will be used regardless of the LLM model being used, even if
+        it is not an openai based model.
+
+        If not using LiteLLM provider, set the encoding model based on the LLM model name.
+        This is for backward compatibility with existing fnllm provider until fnllm is removed.
+
         Raises
         ------
         KeyError
             If the model name is not recognized.
         """
-        if self.encoding_model.strip() == "":
+        if (
+            self.type != ModelType.Chat
+            and self.type != ModelType.Embedding
+            and self.encoding_model.strip() == ""
+        ):
             self.encoding_model = tiktoken.encoding_name_for_model(self.model)

     api_base: str | None = Field(
@@ -129,6 +168,7 @@ def _validate_api_base(self) -> None:
         if (
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
+            or self.model_provider == "azure"  # indicates Litellm + AOI
         ) and (self.api_base is None or self.api_base.strip() == ""):
             raise AzureApiBaseMissingError(self.type)

@@ -150,6 +190,7 @@ def _validate_api_version(self) -> None:
         if (
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
+            or self.model_provider == "azure"  # indicates Litellm + AOI
         ) and (self.api_version is None or self.api_version.strip() == ""):
             raise AzureApiVersionMissingError(self.type)

@@ -171,6 +212,7 @@ def _validate_deployment_name(self) -> None:
         if (
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
+            or self.model_provider == "azure"  # indicates Litellm + AOI
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
             raise AzureDeploymentNameMissingError(self.type)

@@ -212,6 +254,14 @@ def _validate_tokens_per_minute(self) -> None:
             msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
             raise ValueError(msg)

+        if (
+            (self.type == ModelType.Chat or self.type == ModelType.Embedding)
+            and self.rate_limit_strategy is not None
+            and self.tokens_per_minute == "auto"
+        ):
+            msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
+            raise ValueError(msg)
+
     requests_per_minute: int | Literal["auto"] | None = Field(
         description="The number of requests per minute to use for the LLM service.",
         default=language_model_defaults.requests_per_minute,
@@ -230,6 +280,19 @@ def _validate_requests_per_minute(self) -> None:
             msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
             raise ValueError(msg)

+        if (
+            (self.type == ModelType.Chat or self.type == ModelType.Embedding)
+            and self.rate_limit_strategy is not None
+            and self.requests_per_minute == "auto"
+        ):
+            msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
+            raise ValueError(msg)
+
+    rate_limit_strategy: str | None = Field(
+        description="The rate limit strategy to use for the LLM service.",
+        default=language_model_defaults.rate_limit_strategy,
+    )
+
     retry_strategy: str = Field(
         description="The retry strategy to use for the LLM service.",
         default=language_model_defaults.retry_strategy,
@@ -318,6 +381,7 @@ def _validate_azure_settings(self) -> None:
     @model_validator(mode="after")
     def _validate_model(self):
         self._validate_type()
+        self._validate_model_provider()
         self._validate_auth_type()
         self._validate_api_key()
         self._validate_tokens_per_minute()
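
Taken together, the validators in this file encode the configuration contract for the LiteLLM path: type "chat" or "embedding" requires model_provider, model_provider == "azure" pulls in the api_base/api_version/deployment_name checks, and "auto" rate limits are rejected while a rate_limit_strategy is active. A sketch of an entry that satisfies those rules; the values and the endpoint placeholder are illustrative assumptions, and the dict shape assumes the usual pydantic construction of LanguageModelConfig:

# Hypothetical Azure OpenAI entry routed through LiteLLM; values are illustrative.
azure_litellm_chat = {
    "type": "chat",                                      # LiteLLM chat provider
    "model_provider": "azure",                           # triggers the Azure-specific validators
    "model": "gpt-4o",                                   # illustrative model name
    "deployment_name": "gpt-4o",                         # required because model_provider == "azure"
    "api_base": "https://<resource>.openai.azure.com",   # required; placeholder endpoint
    "api_version": "2024-02-01",                         # required; illustrative version
    "auth_type": "azure_managed_identity",               # allowed when model_provider == "azure"
    "rate_limit_strategy": "static",
    "tokens_per_minute": 50_000,                         # "auto" would raise ValueError here
    "requests_per_minute": 300,                          # None (not 0) would disable the limit
}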

graphrag/factory/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Factory module."""
