Commit 6c86b0a

Init config cleanup (#2084)
* Spruce up init_config output, including LiteLLM default
* Remove deployment_name requirement for Azure
* Semver
* Add model_provider
* Add default model_provider
* Remove OBE test
* Update minimal config for tests
* Add model_provider to verb tests
1 parent 2bd3922 commit 6c86b0a

10 files changed, +46 -61 lines changed

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+  "type": "minor",
+  "description": "Set LiteLLM as default in init_content."
+}

graphrag/config/defaults.py

Lines changed: 7 additions & 6 deletions
@@ -6,7 +6,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import ClassVar, Literal
+from typing import ClassVar
 
 from graphrag.config.embeddings import default_embeddings
 from graphrag.config.enums import (
@@ -46,13 +46,14 @@
 
 DEFAULT_OUTPUT_BASE_DIR = "output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
-DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
+DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
 DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
 DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
 DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
-DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
+DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
 DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
 DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
+DEFAULT_MODEL_PROVIDER = "openai"
 DEFAULT_VECTOR_STORE_ID = "default_vector_store"
 
 ENCODING_MODEL = "cl100k_base"
@@ -325,10 +326,10 @@ class LanguageModelDefaults:
     proxy: None = None
     audience: None = None
     model_supports_json: None = None
-    tokens_per_minute: Literal["auto"] = "auto"
-    requests_per_minute: Literal["auto"] = "auto"
+    tokens_per_minute: None = None
+    requests_per_minute: None = None
    rate_limit_strategy: str | None = "static"
-    retry_strategy: str = "native"
+    retry_strategy: str = "exponential_backoff"
     max_retries: int = 10
     max_retry_wait: float = 10.0
     concurrent_requests: int = 25

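A minimal sketch of how the new defaults surface when building a config programmatically. create_graphrag_config and the defaults module are exercised this way in the repo's own unit tests; the exact import paths and the placeholder API key below are assumptions for illustration.

import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config  # import path assumed

models = {
    defs.DEFAULT_CHAT_MODEL_ID: {
        "api_key": "fake-key",                            # placeholder only
        "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,       # "chat" after this commit
        "model": defs.DEFAULT_CHAT_MODEL,
        "model_provider": defs.DEFAULT_MODEL_PROVIDER,    # new default: "openai"
    },
    defs.DEFAULT_EMBEDDING_MODEL_ID: {
        "api_key": "fake-key",
        "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,  # "embedding"
        "model": defs.DEFAULT_EMBEDDING_MODEL,
        "model_provider": defs.DEFAULT_MODEL_PROVIDER,
    },
}

config = create_graphrag_config({"models": models})
print(config.models[defs.DEFAULT_CHAT_MODEL_ID].model_provider)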
graphrag/config/errors.py

Lines changed: 0 additions & 9 deletions
@@ -33,15 +33,6 @@ def __init__(self, llm_type: str) -> None:
         super().__init__(msg)
 
 
-class AzureDeploymentNameMissingError(ValueError):
-    """Azure Deployment Name missing error."""
-
-    def __init__(self, llm_type: str) -> None:
-        """Init method definition."""
-        msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
-        super().__init__(msg)
-
-
 class LanguageModelConfigMissingError(ValueError):
     """Missing model configuration error."""

graphrag/config/init_content.py

Lines changed: 18 additions & 27 deletions
@@ -19,41 +19,34 @@
 
 models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
-    type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-05-01-preview
+    type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
+    model_provider: {defs.DEFAULT_MODEL_PROVIDER}
     auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
-    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # organization: <organization_id>
+    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
     model: {defs.DEFAULT_CHAT_MODEL}
-    # deployment_name: <azure_model_deployment_name>
-    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
     model_supports_json: true # recommended if this is available for your model.
-    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
+    concurrent_requests: {language_model_defaults.concurrent_requests}
     async_mode: {language_model_defaults.async_mode.value} # or asyncio
-    retry_strategy: native
+    retry_strategy: {language_model_defaults.retry_strategy}
     max_retries: {language_model_defaults.max_retries}
-    tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
-    requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
+    tokens_per_minute: null
+    requests_per_minute: null
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
-    type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-05-01-preview
-    auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
+    type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
+    model_provider: {defs.DEFAULT_MODEL_PROVIDER}
+    auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
     api_key: ${{GRAPHRAG_API_KEY}}
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # organization: <organization_id>
     model: {defs.DEFAULT_EMBEDDING_MODEL}
-    # deployment_name: <azure_model_deployment_name>
-    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
-    model_supports_json: true # recommended if this is available for your model.
-    concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
+    concurrent_requests: {language_model_defaults.concurrent_requests}
     async_mode: {language_model_defaults.async_mode.value} # or asyncio
-    retry_strategy: native
+    retry_strategy: {language_model_defaults.retry_strategy}
     max_retries: {language_model_defaults.max_retries}
-    tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
-    requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
+    tokens_per_minute: null
+    requests_per_minute: null
 
 ### Input settings ###
 
@@ -62,7 +55,6 @@
     type: {graphrag_config_defaults.input.storage.type.value} # or blob
     base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
   file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
-
 
 chunks:
   size: {graphrag_config_defaults.chunks.size}
@@ -90,7 +82,6 @@
     type: {vector_store_defaults.type}
     db_uri: {vector_store_defaults.db_uri}
     container_name: {vector_store_defaults.container_name}
-    overwrite: {vector_store_defaults.overwrite}
 
 ### Workflow settings ###

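For reference, a rough sketch of how the rendered models block of settings.yaml reads after `graphrag init` with these changes, with the defaults above substituted in (api_key auth and threaded async mode are assumed defaults; the real generated file may differ in comments and ordering). PyYAML is used only to show the block parses cleanly.

import yaml  # PyYAML, assumed available

rendered_models = """
models:
  default_chat_model:
    type: chat
    model_provider: openai
    auth_type: api_key  # or azure_managed_identity
    api_key: ${GRAPHRAG_API_KEY}  # set this in the generated .env file, or remove if managed identity
    model: gpt-4-turbo-preview
    # api_base: https://<instance>.openai.azure.com
    # api_version: 2024-05-01-preview
    model_supports_json: true
    concurrent_requests: 25
    async_mode: threaded  # or asyncio
    retry_strategy: exponential_backoff
    max_retries: 10
    tokens_per_minute: null
    requests_per_minute: null
"""

parsed = yaml.safe_load(rendered_models)
print(parsed["models"]["default_chat_model"]["model_provider"])  # -> "openai"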
graphrag/config/models/language_model_config.py

Lines changed: 5 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 """Language model configuration."""
 
+import logging
 from typing import Literal
 
 import tiktoken
@@ -14,11 +15,12 @@
     ApiKeyMissingError,
     AzureApiBaseMissingError,
     AzureApiVersionMissingError,
-    AzureDeploymentNameMissingError,
     ConflictingSettingsError,
 )
 from graphrag.language_model.factory import ModelFactory
 
+logger = logging.getLogger(__name__)
+
 
 class LanguageModelConfig(BaseModel):
     """Language model configuration."""
@@ -214,7 +216,8 @@ def _validate_deployment_name(self) -> None:
             or self.type == ModelType.AzureOpenAIEmbedding
             or self.model_provider == "azure" # indicates Litellm + AOI
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
-            raise AzureDeploymentNameMissingError(self.type)
+            msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
+            logger.debug(msg)
 
     organization: str | None = Field(
         description="The organization to use for the LLM service.",
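A hedged sketch of the relaxed validation: an Azure-hosted model config without deployment_name now passes config creation and only emits a debug log, with deployment_name expected to fall back to the model name downstream. The import paths and the minimal set of Azure fields (mirroring the removed unit test's shape) are assumptions.

import logging

import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config  # import path assumed

logging.basicConfig(level=logging.DEBUG)  # surface the validator's debug message

azure_chat_no_deployment = {
    "api_key": "fake-key",
    "type": "azure_openai_chat",
    "model": "gpt-4-turbo-preview",
    "api_base": "https://example.openai.azure.com",
    "api_version": "2024-05-01-preview",
    # no deployment_name -- previously this raised AzureDeploymentNameMissingError
}
embedding = {
    "api_key": "fake-key",
    "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
    "model": defs.DEFAULT_EMBEDDING_MODEL,
    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
}

config = create_graphrag_config({
    "models": {
        defs.DEFAULT_CHAT_MODEL_ID: azure_chat_no_deployment,
        defs.DEFAULT_EMBEDDING_MODEL_ID: embedding,
    }
})
# Reaching this point without a ValidationError demonstrates the new behavior.
print("config created without deployment_name")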
Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,11 @@
 models:
   default_chat_model:
     api_key: ${CUSTOM_API_KEY}
-    type: openai_chat
+    type: chat
+    model_provider: openai
     model: gpt-4-turbo-preview
   default_embedding_model:
     api_key: ${CUSTOM_API_KEY}
-    type: openai_embedding
+    type: embedding
+    model_provider: openai
     model: text-embedding-3-small

Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,11 @@
 models:
   default_chat_model:
     api_key: ${SOME_NON_EXISTENT_ENV_VAR}
-    type: openai_chat
+    type: chat
+    model_provider: openai
     model: gpt-4-turbo-preview
   default_embedding_model:
     api_key: ${SOME_NON_EXISTENT_ENV_VAR}
-    type: openai_embedding
+    type: embedding
+    model_provider: openai
     model: text-embedding-3-small

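Both fixtures lean on ${VAR}-style environment references. Purely as an illustration of that behavior (graphrag's own config loader handles the substitution internally; the Template-based code below is a stand-in, not the library's implementation), this shows why one fixture resolves while the other is useful as a negative case:

import os
from string import Template

os.environ["CUSTOM_API_KEY"] = "fake-key"

raw = "api_key: ${CUSTOM_API_KEY}"
resolved = Template(raw).substitute(os.environ)  # -> "api_key: fake-key"

# An unset variable such as ${SOME_NON_EXISTENT_ENV_VAR} raises KeyError here,
# mirroring the error path the second fixture is meant to exercise.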
tests/unit/config/test_config.py

Lines changed: 0 additions & 13 deletions
@@ -133,19 +133,6 @@ def test_missing_azure_api_version() -> None:
         })
 
 
-def test_missing_azure_deployment_name() -> None:
-    missing_deployment_name_config = base_azure_model_config.copy()
-    del missing_deployment_name_config["deployment_name"]
-
-    with pytest.raises(ValidationError):
-        create_graphrag_config({
-            "models": {
-                defs.DEFAULT_CHAT_MODEL_ID: missing_deployment_name_config,
-                defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
-            }
-        })
-
-
 def test_default_config() -> None:
     expected = get_default_graphrag_config()
     actual = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

tests/unit/config/utils.py

Lines changed: 2 additions & 0 deletions
@@ -41,12 +41,14 @@
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
     "model": defs.DEFAULT_CHAT_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_EMBEDDING_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
     "model": defs.DEFAULT_EMBEDDING_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_MODEL_CONFIG = {

tests/verbs/util.py

Lines changed: 2 additions & 0 deletions
@@ -17,12 +17,14 @@
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_CHAT_MODEL_TYPE.value,
     "model": defs.DEFAULT_CHAT_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_EMBEDDING_MODEL_CONFIG = {
     "api_key": FAKE_API_KEY,
     "type": defs.DEFAULT_EMBEDDING_MODEL_TYPE.value,
     "model": defs.DEFAULT_EMBEDDING_MODEL,
+    "model_provider": defs.DEFAULT_MODEL_PROVIDER,
 }
 
 DEFAULT_MODEL_CONFIG = {
