
Commit 7bdeaee

Create Language Model Providers and Registry methods. Remove fnllm coupling (#1724)
* Base structure
* Add fnllm providers and Mock LLM
* Remove fnllm coupling, introduce llm providers
* Ruff + Tests fix
* Spellcheck
* Semver
* Format
* Default MockChat params
* Fix more tests
* Fix embedding smoke test
* Fix embeddings smoke test
* Fix MockEmbeddingLLM
* Rename LLM to model. Package organization
* Fix prompt tuning
* Oops
* Oops II
1 parent a42772d commit 7bdeaee
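As context for the diffs below: the commit swaps direct fnllm usage for a provider registry, where chat and embedding models are registered under a type key and constructed through a factory. The sketch below is illustrative only, assuming a simple class-level registry; apart from the mock-chat idea and the ModelType-style string keys, none of these names come from graphrag.

```python
# Illustrative sketch of a provider registry/factory, not graphrag's actual
# implementation: chat-model classes register under a type key and are
# instantiated on demand, which is what decouples callers from fnllm.
from typing import Callable, Protocol


class ChatModel(Protocol):
    def chat(self, prompt: str) -> str: ...


class ModelFactorySketch:
    _chat_registry: dict[str, Callable[..., ChatModel]] = {}

    @classmethod
    def register_chat(cls, model_type: str, creator: Callable[..., ChatModel]) -> None:
        """Register a creator callable for a given model type key."""
        cls._chat_registry[model_type] = creator

    @classmethod
    def create_chat_model(cls, model_type: str, **kwargs) -> ChatModel:
        """Instantiate the provider registered under model_type."""
        return cls._chat_registry[model_type](**kwargs)


class MockChatSketch:
    """Stand-in for the commit's MockChat debug provider."""

    def chat(self, prompt: str) -> str:
        return "mock response"


ModelFactorySketch.register_chat("mock_chat", MockChatSketch)
model = ModelFactorySketch.create_chat_model("mock_chat")
```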

File tree

106 files changed: +1210 additions, -687 deletions

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add LMM Manager and Factory, to support provider registration"
+}

graphrag/api/prompt_tune.py

Lines changed: 6 additions & 5 deletions
@@ -16,7 +16,7 @@
 import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.llm.load_llm import load_llm
+from graphrag.language_model.manager import ModelManager
 from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
 from graphrag.prompt_tune.generator.community_report_rating import (
@@ -104,11 +104,12 @@ async def generate_indexing_prompts(
     if default_llm_settings.max_retries == -1:
         default_llm_settings.max_retries = min(len(doc_list), defs.LLM_MAX_RETRIES)

-    llm = load_llm(
-        "prompt_tuning",
-        default_llm_settings,
-        cache=None,
+    llm = ModelManager().register_chat(
+        name="prompt_tuning",
+        model_type=default_llm_settings.type,
+        config=default_llm_settings,
         callbacks=NoopWorkflowCallbacks(),
+        cache=None,
     )

     if not domain:

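For context, a minimal sketch of how the new registry call replaces the old `load_llm` helper. It uses only the `ModelManager().register_chat(...)` signature visible in this hunk; the wrapper function and the settings parameter name are illustrative, and the returned model's chat interface is not shown.

```python
# Hedged sketch: registering a chat model through the new ModelManager registry.
# Assumes `settings` is a LanguageModelConfig-like object; names other than
# ModelManager.register_chat (taken from the diff above) are illustrative.
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.language_model.manager import ModelManager


def build_prompt_tuning_model(settings):
    """Register a chat model under a name in the shared registry and return it."""
    return ModelManager().register_chat(
        name="prompt_tuning",       # registry key for later lookup
        model_type=settings.type,   # e.g. an OpenAI or mock ModelType value
        config=settings,            # full model configuration for the provider
        callbacks=NoopWorkflowCallbacks(),
        cache=None,                 # no LLM cache for prompt tuning
    )
```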
graphrag/cli/index.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def _run_index(
         True,
     )

-    if skip_validation:
+    if not skip_validation:
         validate_config_names(progress_logger, config)

     info(f"Starting pipeline run. {dry_run=}", verbose)

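This one-line fix inverts a boolean guard so that configuration names are validated unless the user explicitly opts out. A minimal, self-contained sketch of the intended behavior; every name here is an illustrative placeholder rather than graphrag's real API, only the guard logic mirrors the diff.

```python
# Hedged sketch of the corrected guard: validation runs unless explicitly skipped.
def validate_config_names(config: dict) -> None:
    """Illustrative stand-in for the real validator (which also takes a progress logger)."""
    if not config.get("models"):
        raise ValueError("no language models configured")


def run_index(config: dict, skip_validation: bool = False) -> None:
    """Mirrors the fixed control flow in _run_index."""
    if not skip_validation:  # old code read `if skip_validation:`, skipping validation by default
        validate_config_names(config)
    print("Starting pipeline run.")


run_index({"models": ["default_chat_model"]})  # validates, then runs
run_index({}, skip_validation=True)            # explicit opt-out of validation
```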
graphrag/config/defaults.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
     ChunkStrategyType,
     InputFileType,
     InputType,
-    LLMType,
+    ModelType,
     NounPhraseExtractorType,
     OutputType,
     ReportingType,
@@ -30,7 +30,7 @@
 # LLM Parameters
 #
 LLM_FREQUENCY_PENALTY = 0.0
-LLM_TYPE = LLMType.OpenAIChat
+LLM_TYPE = ModelType.OpenAIChat
 LLM_MODEL = "gpt-4-turbo-preview"
 LLM_MAX_TOKENS = 4000
 LLM_TEMPERATURE = 0
@@ -48,7 +48,7 @@
 #
 # Text embedding
 #
-EMBEDDING_TYPE = LLMType.OpenAIEmbedding
+EMBEDDING_TYPE = ModelType.OpenAIEmbedding
 EMBEDDING_MODEL = "text-embedding-3-small"
 EMBEDDING_BATCH_SIZE = 16
 EMBEDDING_BATCH_MAX_TOKENS = 8191

graphrag/config/enums.py

Lines changed: 3 additions & 2 deletions
@@ -98,7 +98,7 @@ def __repr__(self):
         return f'"{self.value}"'


-class LLMType(str, Enum):
+class ModelType(str, Enum):
     """LLMType enum class definition."""

     # Embeddings
@@ -110,7 +110,8 @@ class LLMType(str, Enum):
     AzureOpenAIChat = "azure_openai_chat"

     # Debug
-    StaticResponse = "static_response"
+    MockChat = "mock_chat"
+    MockEmbedding = "mock_embedding"

     def __repr__(self):
         """Get a string representation."""

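To illustrate the rename, here is a trimmed, self-contained copy of the enum as it looks after this change, showing only members that appear in the hunks above plus the string round-tripping that subclassing `str, Enum` gives config parsing. Values marked as assumed are not visible in this diff.

```python
from enum import Enum


class ModelType(str, Enum):
    """Trimmed copy of graphrag.config.enums.ModelType; members beyond this diff omitted."""

    OpenAIChat = "openai_chat"            # assumed value; only the member name appears in the diff
    OpenAIEmbedding = "openai_embedding"  # assumed value
    AzureOpenAIChat = "azure_openai_chat"
    # Debug providers replacing the old StaticResponse entry
    MockChat = "mock_chat"
    MockEmbedding = "mock_embedding"


# Because the enum subclasses str, settings strings map directly to members:
assert ModelType("mock_chat") is ModelType.MockChat
assert ModelType.MockChat == "mock_chat"
```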
graphrag/config/models/language_model_config.py

Lines changed: 9 additions & 9 deletions
@@ -7,7 +7,7 @@
 from pydantic import BaseModel, Field, model_validator

 import graphrag.config.defaults as defs
-from graphrag.config.enums import AsyncType, AuthType, LLMType
+from graphrag.config.enums import AsyncType, AuthType, ModelType
 from graphrag.config.errors import (
     ApiKeyMissingError,
     AzureApiBaseMissingError,
@@ -71,12 +71,12 @@ def _validate_auth_type(self) -> None:
         If the Azure authentication type conflicts with the model being used.
         """
         if self.auth_type == AuthType.AzureManagedIdentity and (
-            self.type == LLMType.OpenAIChat or self.type == LLMType.OpenAIEmbedding
+            self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
         ):
             msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
             raise ConflictingSettingsError(msg)

-    type: LLMType = Field(description="The type of LLM model to use.")
+    type: ModelType = Field(description="The type of LLM model to use.")
     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(description="The encoding model to use", default="")

@@ -133,8 +133,8 @@ def _validate_api_base(self) -> None:
         If the API base is missing and is required.
         """
         if (
-            self.type == LLMType.AzureOpenAIChat
-            or self.type == LLMType.AzureOpenAIEmbedding
+            self.type == ModelType.AzureOpenAIChat
+            or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_base is None or self.api_base.strip() == ""):
             raise AzureApiBaseMissingError(self.type.value)

@@ -153,8 +153,8 @@ def _validate_api_version(self) -> None:
         If the API base is missing and is required.
         """
         if (
-            self.type == LLMType.AzureOpenAIChat
-            or self.type == LLMType.AzureOpenAIEmbedding
+            self.type == ModelType.AzureOpenAIChat
+            or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_version is None or self.api_version.strip() == ""):
             raise AzureApiVersionMissingError(self.type.value)

@@ -173,8 +173,8 @@ def _validate_deployment_name(self) -> None:
         If the deployment name is missing and is required.
         """
         if (
-            self.type == LLMType.AzureOpenAIChat
-            or self.type == LLMType.AzureOpenAIEmbedding
+            self.type == ModelType.AzureOpenAIChat
+            or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
             raise AzureDeploymentNameMissingError(self.type.value)

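The hunks above only rename the enum inside existing validators. As a reference for the pattern itself, here is a hedged, self-contained pydantic v2 sketch of a config model that validates provider-specific fields after construction. Field names echo the diff; the trimmed enum, the class name, and the plain ValueError stand in for graphrag's real error classes and validator wiring, which are not reproduced here.

```python
# Hedged sketch of the validation pattern: an after-validator that enforces
# Azure-only settings, reduced to a minimal standalone pydantic v2 model.
from enum import Enum

from pydantic import BaseModel, Field, model_validator


class ModelType(str, Enum):
    AzureOpenAIChat = "azure_openai_chat"
    MockChat = "mock_chat"


class LanguageModelConfigSketch(BaseModel):
    type: ModelType = Field(description="The type of LLM model to use.")
    model: str = Field(description="The LLM model to use.")
    api_base: str | None = None
    deployment_name: str | None = None

    @model_validator(mode="after")
    def _validate_azure_settings(self) -> "LanguageModelConfigSketch":
        # Azure chat models need an endpoint; mock models do not.
        if self.type == ModelType.AzureOpenAIChat and not (self.api_base or "").strip():
            msg = f"api_base is required for model type {self.type.value}"
            raise ValueError(msg)
        return self


# A mock model validates without any endpoint configured.
LanguageModelConfigSketch(type=ModelType.MockChat, model="mock")
```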
Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Any

-from graphrag.model.named import Named
+from graphrag.data_model.named import Named


 @dataclass
Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Any

-from graphrag.model.named import Named
+from graphrag.data_model.named import Named


 @dataclass
Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Any

-from graphrag.model.identified import Identified
+from graphrag.data_model.identified import Identified


 @dataclass
