Commit f14cda2

Improve default llm retry logic to be more optimized (#1701)

1 parent: b8b949f

50 files changed: +606 −567 lines changed
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add dynamic retry logic."
+}

graphrag/api/prompt_tune.py

Lines changed: 8 additions & 1 deletion

@@ -13,6 +13,7 @@
 
 from pydantic import PositiveInt, validate_call
 
+import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm

@@ -95,8 +96,14 @@ async def generate_indexing_prompts(
     )
 
     # Create LLM from config
-    # TODO: Expose way to specify Prompt Tuning model ID through config
+    # TODO: Expose a way to specify Prompt Tuning model ID through config
     default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
+
+    # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
+    # to be made or fallback to a default value in the worst case
+    if default_llm_settings.max_retries == -1:
+        default_llm_settings.max_retries = min(len(doc_list), defs.LLM_MAX_RETRIES)
+
     llm = load_llm(
         "prompt_tuning",
         default_llm_settings,
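
The interesting piece here is the -1 sentinel: prompt tuning makes roughly one LLM call per sampled document, so the retry budget is sized to the workload instead of a fixed constant. A minimal sketch of that resolution logic, assuming LLM_MAX_RETRIES = 10 as set in graphrag/config/defaults.py below (resolve_max_retries is a hypothetical helper, not part of this commit):

# Sketch of the dynamic retry resolution introduced above.
LLM_MAX_RETRIES = 10  # fallback ceiling from graphrag/config/defaults.py

def resolve_max_retries(configured: int, expected_llm_calls: int) -> int:
    # -1 means "not set by the user"; size the retry budget to the number
    # of expected LLM calls, capped by the static default in the worst case.
    if configured == -1:
        return min(expected_llm_calls, LLM_MAX_RETRIES)
    return configured

assert resolve_max_retries(-1, 4) == 4    # 4 sampled docs -> 4 retries
assert resolve_max_retries(-1, 50) == 10  # capped by LLM_MAX_RETRIES
assert resolve_max_retries(5, 50) == 5    # an explicit setting always wins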

graphrag/config/defaults.py

Lines changed: 2 additions & 5 deletions

@@ -24,7 +24,7 @@
 DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
-AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
 AUTH_TYPE = AuthType.APIKey
 #
 # LLM Parameters

@@ -39,15 +39,12 @@
 LLM_REQUEST_TIMEOUT = 180.0
 LLM_TOKENS_PER_MINUTE = 50_000
 LLM_REQUESTS_PER_MINUTE = 1_000
+RETRY_STRATEGY = "native"
 LLM_MAX_RETRIES = 10
 LLM_MAX_RETRY_WAIT = 10.0
 LLM_PRESENCE_PENALTY = 0.0
-LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
 LLM_CONCURRENT_REQUESTS = 25
 
-PARALLELIZATION_STAGGER = 0.3
-PARALLELIZATION_NUM_THREADS = 50
-
 #
 # Text embedding
 #
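
Net effect for anything that imported the removed constants: staggered thread pools are gone, and both concurrency and retries are driven by the LLM settings. A hypothetical migration sketch (the "before" lines reflect the constants deleted above):

import graphrag.config.defaults as defs

# before this commit (constants deleted above, no longer valid):
#   num_threads = defs.PARALLELIZATION_NUM_THREADS  # was 50
#   stagger = defs.PARALLELIZATION_STAGGER          # was 0.3

# after this commit, a single knob drives concurrency, and retries are explicit:
num_threads = defs.LLM_CONCURRENT_REQUESTS  # 25
retry_strategy = defs.RETRY_STRATEGY        # "native"
max_retries = defs.LLM_MAX_RETRIES          # 10, the ceiling for dynamic retries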

graphrag/config/init_content.py

Lines changed: 24 additions & 15 deletions

@@ -14,32 +14,41 @@
 
 models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     type: {defs.LLM_TYPE.value} # or azure_openai_chat
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
     auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
-    model: {defs.LLM_MODEL}
-    model_supports_json: true # recommended if this is available for your model.
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
-    async_mode: {defs.ASYNC_MODE.value} # or asyncio
+    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     # audience: "https://cognitiveservices.azure.com/.default"
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
     # organization: <organization_id>
+    model: {defs.LLM_MODEL}
     # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
+    model_supports_json: true # recommended if this is available for your model.
+    concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
+    async_mode: {defs.ASYNC_MODE.value} # or asyncio
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
-    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
-    model: {defs.EMBEDDING_MODEL}
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
-    async_mode: {defs.ASYNC_MODE.value} # or asyncio
     # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
+    # api_version: 2024-05-01-preview
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
+    api_key: ${{GRAPHRAG_API_KEY}}
     # audience: "https://cognitiveservices.azure.com/.default"
     # organization: <organization_id>
+    model: {defs.EMBEDDING_MODEL}
     # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
+    model_supports_json: true # recommended if this is available for your model.
+    concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
+    async_mode: {defs.ASYNC_MODE.value} # or asyncio
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting
 
 vector_store:
   {defs.VECTOR_STORE_DEFAULT_ID}:
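
For users, the visible change is in the generated settings.yaml: the parallelization_* keys disappear, and each model block gains retry_strategy, max_retries, and explicit rate-limit knobs. A self-contained sketch of reading the new keys (values mirror the template above; assumes pyyaml is installed):

import yaml

settings = yaml.safe_load("""
models:
  default_chat_model:
    retry_strategy: native
    max_retries: -1
    tokens_per_minute: 0
    requests_per_minute: 0
""")

chat = settings["models"]["default_chat_model"]
assert chat["max_retries"] == -1        # sentinel: retries resolved dynamically
assert chat["tokens_per_minute"] == 0   # 0 disables rate limiting
assert chat["requests_per_minute"] == 0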

graphrag/config/models/community_reports_config.py

Lines changed: 1 addition & 2 deletions

@@ -49,8 +49,7 @@ def resolved_strategy(
         return self.strategy or {
             "type": CreateCommunityReportsStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )
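
The same one-line substitution repeats in the extract-claims, extract-graph, and summarize-descriptions configs below: the per-workflow "stagger" knob is dropped, and worker threads are sized by the model's concurrent_requests. A sketch of the resulting strategy dict, with hypothetical placeholder values:

# model_config stands in for a LanguageModelConfig dump; values are illustrative.
model_config = {"model": "gpt-4o", "concurrent_requests": 25}

strategy = {
    "llm": model_config,
    # one knob now drives both request concurrency and worker-thread count
    "num_threads": model_config["concurrent_requests"],
}
assert "stagger" not in strategy  # removed along with parallelization_stagger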

graphrag/config/models/extract_claims_config.py

Lines changed: 1 addition & 2 deletions

@@ -46,8 +46,7 @@ def resolved_strategy(
         """Get the resolved claim extraction strategy."""
         return self.strategy or {
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )

graphrag/config/models/extract_graph_config.py

Lines changed: 1 addition & 2 deletions

@@ -47,8 +47,7 @@ def resolved_strategy(
         return self.strategy or {
             "type": ExtractEntityStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )

graphrag/config/models/extract_graph_nlp_config.py

Lines changed: 2 additions & 2 deletions

@@ -64,7 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
     text_analyzer: TextAnalyzerConfig = Field(
         description="The text analyzer configuration.", default=TextAnalyzerConfig()
     )
-    parallelization_num_threads: int = Field(
+    concurrent_requests: int = Field(
         description="The number of threads to use for the extraction process.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
+        default=defs.LLM_CONCURRENT_REQUESTS,
     )
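
Worth noting: the rename also shifts this field's default, from PARALLELIZATION_NUM_THREADS (50) to LLM_CONCURRENT_REQUESTS (25), per the defaults.py diff above. A runnable check of that behavior, using a simplified stand-in for the real class:

from pydantic import BaseModel, Field

LLM_CONCURRENT_REQUESTS = 25  # from graphrag/config/defaults.py in this commit

class ExtractGraphNLPConfigSketch(BaseModel):
    # renamed from parallelization_num_threads, whose default was 50
    concurrent_requests: int = Field(default=LLM_CONCURRENT_REQUESTS)

assert ExtractGraphNLPConfigSketch().concurrent_requests == 25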

graphrag/config/models/language_model_config.py

Lines changed: 5 additions & 13 deletions

@@ -31,7 +31,7 @@ def _validate_api_key(self) -> None:
         API Key is required when using OpenAI API
         or when using Azure API with API Key authentication.
         For the time being, this check is extra verbose for clarity.
-        It will also through an exception if an API Key is provided
+        It will also raise an exception if an API Key is provided
         when one is not expected such as the case of using Azure
         Managed Identity.

@@ -199,6 +199,10 @@ def _validate_deployment_name(self) -> None:
         description="The number of requests per minute to use for the LLM service.",
         default=defs.LLM_REQUESTS_PER_MINUTE,
     )
+    retry_strategy: str = Field(
+        description="The retry strategy to use for the LLM service.",
+        default=defs.RETRY_STRATEGY,
+    )
     max_retries: int = Field(
         description="The maximum number of retries to use for the LLM service.",
         default=defs.LLM_MAX_RETRIES,

@@ -207,25 +211,13 @@ def _validate_deployment_name(self) -> None:
         description="The maximum retry wait to use for the LLM service.",
         default=defs.LLM_MAX_RETRY_WAIT,
     )
-    sleep_on_rate_limit_recommendation: bool = Field(
-        description="Whether to sleep on rate limit recommendations.",
-        default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
-    )
     concurrent_requests: int = Field(
         description="Whether to use concurrent requests for the LLM service.",
         default=defs.LLM_CONCURRENT_REQUESTS,
     )
     responses: list[str | BaseModel] | None = Field(
         default=None, description="Static responses to use in mock mode."
     )
-    parallelization_stagger: float = Field(
-        description="The stagger to use for the LLM service.",
-        default=defs.PARALLELIZATION_STAGGER,
-    )
-    parallelization_num_threads: int = Field(
-        description="The number of threads to use for the LLM service.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
-    )
     async_mode: AsyncType = Field(
         description="The async mode to use.", default=defs.ASYNC_MODE
     )
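
Taken together, the model schema now exposes retry_strategy alongside the -1 sentinel for max_retries, while sleep_on_rate_limit_recommendation and the parallelization fields are removed. A standalone pydantic sketch of just the retry surface (not the real class, which carries many more fields and validators):

from pydantic import BaseModel, Field

# defaults from graphrag/config/defaults.py in this commit
RETRY_STRATEGY = "native"
LLM_MAX_RETRIES = 10

class RetrySettingsSketch(BaseModel):
    retry_strategy: str = Field(default=RETRY_STRATEGY)
    max_retries: int = Field(default=LLM_MAX_RETRIES)

# the generated settings.yaml opts into dynamic retries via the -1 sentinel
cfg = RetrySettingsSketch(max_retries=-1)
assert cfg.retry_strategy == "native"
assert cfg.max_retries == -1  # resolved at call sites, e.g. min(len(doc_list), 10)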

graphrag/config/models/summarize_descriptions_config.py

Lines changed: 1 addition & 2 deletions

@@ -40,8 +40,7 @@ def resolved_strategy(
         return self.strategy or {
             "type": SummarizeStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "summarize_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
