@@ -73,8 +73,11 @@ def _validate_auth_type(self) -> None:
7373 ConflictingSettingsError
7474 If the Azure authentication type conflicts with the model being used.
7575 """
76- if self .auth_type == AuthType .AzureManagedIdentity and (
77- self .type == ModelType .OpenAIChat or self .type == ModelType .OpenAIEmbedding
76+ if (
77+ self .auth_type == AuthType .AzureManagedIdentity
78+ and self .type != ModelType .AzureOpenAIChat
79+ and self .type != ModelType .AzureOpenAIEmbedding
80+ and self .model_provider != "azure" # indicates Litellm + AOI
7881 ):
7982 msg = f"auth_type of azure_managed_identity is not supported for model type { self .type } . Please rerun `graphrag init` and set the auth_type to api_key."
8083 raise ConflictingSettingsError (msg )
@@ -94,6 +97,27 @@ def _validate_type(self) -> None:
9497 msg = f"Model type { self .type } is not recognized, must be one of { ModelFactory .get_chat_models () + ModelFactory .get_embedding_models ()} ."
9598 raise KeyError (msg )
9699
model_provider: str | None = Field(
    description="The model provider to use.",
    default=language_model_defaults.model_provider,
)

def _validate_model_provider(self) -> None:
    """Ensure a model provider is set when a LiteLLM model type is used.

    The generic ``chat``/``embedding`` model types are served through LiteLLM,
    which routes requests by provider name, so ``model_provider`` must be a
    non-blank string for those types.

    Raises
    ------
    KeyError
        If the model provider is missing or blank when it is required.
    """
    needs_provider = self.type in (ModelType.Chat, ModelType.Embedding)
    provider_blank = not (self.model_provider and self.model_provider.strip())
    if needs_provider and provider_blank:
        msg = f"Model provider must be specified when using type == {self.type}."
        raise KeyError(msg)
120+
97121 model : str = Field (description = "The LLM model to use." )
98122 encoding_model : str = Field (
99123 description = "The encoding model to use" ,
@@ -103,12 +127,27 @@ def _validate_type(self) -> None:
def _validate_encoding_model(self) -> None:
    """Resolve a default encoding model for fnllm-based model types.

    The default behavior is to use an encoding model that matches the LLM model.
    LiteLLM supports 100+ models and their tokenization, so no encoding model
    is derived for the LiteLLM ``chat``/``embedding`` types. Users may still
    specify a tiktoken-based encoding model explicitly, in which case it is
    used as-is regardless of the underlying LLM model — even for non-OpenAI
    models.

    For the remaining (fnllm) model types, a blank encoding model is filled in
    from the LLM model name. This preserves backward compatibility with the
    existing fnllm provider until fnllm is removed.

    Raises
    ------
    KeyError
        If the model name is not recognized.
    """
    is_litellm_type = self.type in (ModelType.Chat, ModelType.Embedding)
    if not is_litellm_type and not self.encoding_model.strip():
        self.encoding_model = tiktoken.encoding_name_for_model(self.model)
113152
114153 api_base : str | None = Field (
@@ -129,6 +168,7 @@ def _validate_api_base(self) -> None:
129168 if (
130169 self .type == ModelType .AzureOpenAIChat
131170 or self .type == ModelType .AzureOpenAIEmbedding
171+ or self .model_provider == "azure" # indicates Litellm + AOI
132172 ) and (self .api_base is None or self .api_base .strip () == "" ):
133173 raise AzureApiBaseMissingError (self .type )
134174
@@ -150,6 +190,7 @@ def _validate_api_version(self) -> None:
150190 if (
151191 self .type == ModelType .AzureOpenAIChat
152192 or self .type == ModelType .AzureOpenAIEmbedding
193+ or self .model_provider == "azure" # indicates Litellm + AOI
153194 ) and (self .api_version is None or self .api_version .strip () == "" ):
154195 raise AzureApiVersionMissingError (self .type )
155196
@@ -171,6 +212,7 @@ def _validate_deployment_name(self) -> None:
171212 if (
172213 self .type == ModelType .AzureOpenAIChat
173214 or self .type == ModelType .AzureOpenAIEmbedding
215+ or self .model_provider == "azure" # indicates Litellm + AOI
174216 ) and (self .deployment_name is None or self .deployment_name .strip () == "" ):
175217 raise AzureDeploymentNameMissingError (self .type )
176218
@@ -212,6 +254,14 @@ def _validate_tokens_per_minute(self) -> None:
212254 msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: { language_model_defaults .tokens_per_minute } ."
213255 raise ValueError (msg )
214256
257+ if (
258+ (self .type == ModelType .Chat or self .type == ModelType .Embedding )
259+ and self .rate_limit_strategy is not None
260+ and self .tokens_per_minute == "auto"
261+ ):
262+ msg = f"tokens_per_minute cannot be set to 'auto' when using type '{ self .type } '. Please set it to a positive integer or null to disable."
263+ raise ValueError (msg )
264+
215265 requests_per_minute : int | Literal ["auto" ] | None = Field (
216266 description = "The number of requests per minute to use for the LLM service." ,
217267 default = language_model_defaults .requests_per_minute ,
@@ -230,6 +280,19 @@ def _validate_requests_per_minute(self) -> None:
230280 msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: { language_model_defaults .requests_per_minute } ."
231281 raise ValueError (msg )
232282
283+ if (
284+ (self .type == ModelType .Chat or self .type == ModelType .Embedding )
285+ and self .rate_limit_strategy is not None
286+ and self .requests_per_minute == "auto"
287+ ):
288+ msg = f"requests_per_minute cannot be set to 'auto' when using type '{ self .type } '. Please set it to a positive integer or null to disable."
289+ raise ValueError (msg )
290+
291+ rate_limit_strategy : str | None = Field (
292+ description = "The rate limit strategy to use for the LLM service." ,
293+ default = language_model_defaults .rate_limit_strategy ,
294+ )
295+
233296 retry_strategy : str = Field (
234297 description = "The retry strategy to use for the LLM service." ,
235298 default = language_model_defaults .retry_strategy ,
@@ -318,6 +381,7 @@ def _validate_azure_settings(self) -> None:
318381 @model_validator (mode = "after" )
319382 def _validate_model (self ):
320383 self ._validate_type ()
384+ self ._validate_model_provider ()
321385 self ._validate_auth_type ()
322386 self ._validate_api_key ()
323387 self ._validate_tokens_per_minute ()
0 commit comments