
Commit 624331c (parent 7a06d61)

Update chat and embedding type field for litellm use and future migration away from fnllm.

File tree: 3 files changed (+7, -10 lines)

graphrag/config/enums.py (2 additions, 2 deletions)

@@ -86,12 +86,12 @@ class ModelType(str, Enum):
     # Embeddings
     OpenAIEmbedding = "openai_embedding"
     AzureOpenAIEmbedding = "azure_openai_embedding"
-    LitellmEmbedding = "litellm_embedding"
+    Embedding = "embedding"

     # Chat Completion
     OpenAIChat = "openai_chat"
     AzureOpenAIChat = "azure_openai_chat"
-    LitellmChat = "litellm_chat"
+    Chat = "chat"

     # Debug
     MockChat = "mock_chat"
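
For context, a minimal standalone reproduction of the ModelType members touched above (not an import of graphrag itself); the inline comments mark the renamed litellm entries:

from enum import Enum


class ModelType(str, Enum):
    # Embeddings
    OpenAIEmbedding = "openai_embedding"
    AzureOpenAIEmbedding = "azure_openai_embedding"
    Embedding = "embedding"  # was LitellmEmbedding = "litellm_embedding"

    # Chat Completion
    OpenAIChat = "openai_chat"
    AzureOpenAIChat = "azure_openai_chat"
    Chat = "chat"  # was LitellmChat = "litellm_chat"


# Configs that previously said "litellm_chat" / "litellm_embedding" now use
# the provider-neutral values:
assert ModelType("chat") is ModelType.Chat
assert ModelType("embedding") is ModelType.Embedding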

graphrag/config/models/language_model_config.py (3 additions, 4 deletions)

@@ -113,8 +113,7 @@ def _validate_model_provider(self) -> None:
         If the model provider is not recognized.
         """
         if (
-            self.type == ModelType.LitellmChat
-            or self.type == ModelType.LitellmEmbedding
+            self.type == ModelType.Chat or self.type == ModelType.Embedding
         ) and self.model_provider.strip() == "":
             msg = "Model provider must be specified when using Litellm."
             raise KeyError(msg)
@@ -145,8 +144,8 @@ def _validate_encoding_model(self) -> None:
         If the model name is not recognized.
         """
         if (
-            self.type != ModelType.LitellmChat
-            and self.type != ModelType.LitellmEmbedding
+            self.type != ModelType.Chat
+            and self.type != ModelType.Embedding
             and self.encoding_model.strip() == ""
         ):
             self.encoding_model = tiktoken.encoding_name_for_model(self.model)
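
A rough standalone sketch of the two validation rules in this file, assuming the config fields behave as plain strings; the validate() helper and its signature are illustrative, not graphrag's actual API:

import tiktoken


def validate(type_: str, model: str, model_provider: str = "", encoding_model: str = "") -> str:
    """Apply the two checks from the diff and return the effective encoding model."""
    generic_types = {"chat", "embedding"}  # ModelType.Chat / ModelType.Embedding values

    # Generic (litellm-backed) types must name a provider explicitly.
    if type_ in generic_types and model_provider.strip() == "":
        raise KeyError("Model provider must be specified when using Litellm.")

    # All other types fall back to tiktoken's encoding name for the model.
    if type_ not in generic_types and encoding_model.strip() == "":
        encoding_model = tiktoken.encoding_name_for_model(model)
    return encoding_model


print(validate("openai_chat", "gpt-4"))                    # "cl100k_base"
print(validate("chat", "gpt-4", model_provider="openai"))  # "" (encoding left to litellm)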

graphrag/language_model/factory.py (2 additions, 4 deletions)

@@ -109,9 +109,7 @@ def is_supported_model(cls, model_type: str) -> bool:
 ModelFactory.register_chat(
     ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
 )
-ModelFactory.register_chat(
-    ModelType.LitellmChat, lambda **kwargs: LitellmChatModel(**kwargs)
-)
+ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs))

 ModelFactory.register_embedding(
     ModelType.AzureOpenAIEmbedding.value,
@@ -121,5 +119,5 @@ def is_supported_model(cls, model_type: str) -> bool:
     ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
 )
 ModelFactory.register_embedding(
-    ModelType.LitellmEmbedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
+    ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
 )
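
A hypothetical, simplified registry to illustrate what these register_* calls do; the real ModelFactory internals are not shown in this diff and may differ:

from typing import Any, Callable


class ModelFactory:
    """Toy stand-in keeping separate chat and embedding registries."""

    _chat: dict[str, Callable[..., Any]] = {}
    _embedding: dict[str, Callable[..., Any]] = {}

    @classmethod
    def register_chat(cls, model_type: str, creator: Callable[..., Any]) -> None:
        cls._chat[model_type] = creator

    @classmethod
    def register_embedding(cls, model_type: str, creator: Callable[..., Any]) -> None:
        cls._embedding[model_type] = creator

    @classmethod
    def is_supported_model(cls, model_type: str) -> bool:
        return model_type in cls._chat or model_type in cls._embedding


# Keyed by the new generic type values; the lambdas are placeholders for
# LitellmChatModel / LitellmEmbeddingModel.
ModelFactory.register_chat("chat", lambda **kwargs: object())
ModelFactory.register_embedding("embedding", lambda **kwargs: object())
print(ModelFactory.is_supported_model("chat"))  # True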
