Skip to content

Commit e0d233f

Browse files
Feat/llm provider query (#1735)
* Add ModelProvider to Query package. * Spellcheck + others * Semver * Fix tests * Format * Fix Pyright * Fix tests * Fix for smoke tests
1 parent faa05b6 commit e0d233f

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

57 files changed

+911
-1363
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Use ModelProvider for query module"
4+
}

dictionary.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,16 @@ unhot
120120
groupby
121121
retryer
122122
agenerate
123-
aembed
124-
dedupe
125123
dropna
126-
dtypes
127124
notna
128125

129126
# LLM Terms
130127
AOAI
131128
embedder
132129
llm
133130
llms
131+
achat
132+
aembed
134133

135134
# Galaxy-Brain Terms
136135
Unipartite

graphrag/callbacks/query_callbacks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ def on_reduce_response_start(
2828

2929
def on_reduce_response_end(self, reduce_response_output: str) -> None:
3030
"""Handle the end of reduce operation."""
31+
32+
def on_llm_new_token(self, token) -> None:
33+
"""Handle when a new token is generated."""

graphrag/config/defaults.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class BasicSearchDefaults:
4848
n: int = 1
4949
max_tokens: int = 12_000
5050
llm_max_tokens: int = 2000
51+
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
52+
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
5153

5254

5355
@dataclass
@@ -122,7 +124,9 @@ class DriftSearchDefaults:
122124
local_search_temperature: float = 0
123125
local_search_top_p: float = 1
124126
local_search_n: int = 1
125-
local_search_llm_max_gen_tokens: int = 12_000
127+
local_search_llm_max_gen_tokens: int = 4_096
128+
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
129+
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
126130

127131

128132
@dataclass
@@ -239,6 +243,8 @@ class GlobalSearchDefaults:
239243
dynamic_search_use_summary: bool = False
240244
dynamic_search_concurrent_coroutines: int = 16
241245
dynamic_search_max_level: int = 2
246+
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
247+
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
242248

243249

244250
@dataclass
@@ -305,6 +311,8 @@ class LocalSearchDefaults:
305311
n: int = 1
306312
max_tokens: int = 12_000
307313
llm_max_tokens: int = 2000
314+
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
315+
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
308316

309317

310318
@dataclass

graphrag/config/init_content.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,26 @@
145145
## See the config docs: https://microsoft.github.io/graphrag/config/yaml/#query
146146
147147
local_search:
148+
chat_model_id: {graphrag_config_defaults.local_search.chat_model_id}
149+
embedding_model_id: {graphrag_config_defaults.local_search.embedding_model_id}
148150
prompt: "prompts/local_search_system_prompt.txt"
149151
150152
global_search:
153+
chat_model_id: {graphrag_config_defaults.global_search.chat_model_id}
154+
embedding_model_id: {graphrag_config_defaults.global_search.embedding_model_id}
151155
map_prompt: "prompts/global_search_map_system_prompt.txt"
152156
reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
153157
knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"
154158
155159
drift_search:
160+
chat_model_id: {graphrag_config_defaults.drift_search.chat_model_id}
161+
embedding_model_id: {graphrag_config_defaults.drift_search.embedding_model_id}
156162
prompt: "prompts/drift_search_system_prompt.txt"
157163
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
158164
159165
basic_search:
166+
chat_model_id: {graphrag_config_defaults.basic_search.chat_model_id}
167+
embedding_model_id: {graphrag_config_defaults.basic_search.embedding_model_id}
160168
prompt: "prompts/basic_search_system_prompt.txt"
161169
"""
162170

graphrag/config/models/basic_search_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ class BasicSearchConfig(BaseModel):
1515
description="The basic search prompt to use.",
1616
default=graphrag_config_defaults.basic_search.prompt,
1717
)
18+
chat_model_id: str = Field(
19+
description="The model ID to use for basic search.",
20+
default=graphrag_config_defaults.basic_search.chat_model_id,
21+
)
22+
embedding_model_id: str = Field(
23+
description="The model ID to use for text embeddings.",
24+
default=graphrag_config_defaults.basic_search.embedding_model_id,
25+
)
1826
text_unit_prop: float = Field(
1927
description="The text unit proportion.",
2028
default=graphrag_config_defaults.basic_search.text_unit_prop,

graphrag/config/models/drift_search_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ class DRIFTSearchConfig(BaseModel):
1919
description="The drift search reduce prompt to use.",
2020
default=graphrag_config_defaults.drift_search.reduce_prompt,
2121
)
22+
chat_model_id: str = Field(
23+
description="The model ID to use for drift search.",
24+
default=graphrag_config_defaults.drift_search.chat_model_id,
25+
)
26+
embedding_model_id: str = Field(
27+
description="The model ID to use for text embeddings.",
28+
default=graphrag_config_defaults.drift_search.embedding_model_id,
29+
)
2230
temperature: float = Field(
2331
description="The temperature to use for token generation.",
2432
default=graphrag_config_defaults.drift_search.temperature,

graphrag/config/models/global_search_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
1919
description="The global search reducer to use.",
2020
default=graphrag_config_defaults.global_search.reduce_prompt,
2121
)
22+
chat_model_id: str = Field(
23+
description="The model ID to use for global search.",
24+
default=graphrag_config_defaults.global_search.chat_model_id,
25+
)
2226
knowledge_prompt: str | None = Field(
2327
description="The global search general prompt to use.",
2428
default=graphrag_config_defaults.global_search.knowledge_prompt,

graphrag/config/models/local_search_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ class LocalSearchConfig(BaseModel):
1515
description="The local search prompt to use.",
1616
default=graphrag_config_defaults.local_search.prompt,
1717
)
18+
chat_model_id: str = Field(
19+
description="The model ID to use for local search.",
20+
default=graphrag_config_defaults.local_search.chat_model_id,
21+
)
22+
embedding_model_id: str = Field(
23+
description="The model ID to use for text embeddings.",
24+
default=graphrag_config_defaults.local_search.embedding_model_id,
25+
)
1826
text_unit_prop: float = Field(
1927
description="The text unit proportion.",
2028
default=graphrag_config_defaults.local_search.text_unit_prop,

graphrag/index/operations/embed_text/strategies/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def _execute(
8888
) -> list[list[float]]:
8989
async def embed(chunk: list[str]):
9090
async with semaphore:
91-
chunk_embeddings = await model.embed(chunk)
91+
chunk_embeddings = await model.aembed_batch(chunk)
9292
result = np.array(chunk_embeddings)
9393
tick(1)
9494
return result

0 commit comments

Comments
 (0)