Skip to content

Commit 75bda68

Browse files
committed
Require explicit azure auth settings when using AOI.
- Must set LanguageModel.azure_auth_type to either "api_key" or "managed_identity" when using AOI.
1 parent eeee84e commit 75bda68

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Require explicit azure auth settings when using AOI."
4+
}

graphrag/config/errors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ def __init__(self, llm_type: str, azure_auth_type: str | None = None) -> None:
1515
super().__init__(msg)
1616

1717

18+
class AzureAuthTypeMissingError(ValueError):
19+
"""Azure Auth type missing error."""
20+
21+
def __init__(self, llm_type: str) -> None:
22+
"""Init method definition."""
23+
msg = f"azure_auth_type is required for {llm_type}. Please rerun `graphrag init` and set the azure_auth_type."
24+
super().__init__(msg)
25+
26+
1827
class AzureApiBaseMissingError(ValueError):
1928
"""Azure API Base missing error."""
2029

graphrag/config/models/language_model_config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ApiKeyMissingError,
1313
AzureApiBaseMissingError,
1414
AzureApiVersionMissingError,
15+
AzureAuthTypeMissingError,
1516
AzureDeploymentNameMissingError,
1617
ConflictingSettingsError,
1718
)
@@ -61,6 +62,23 @@ def _validate_api_key(self) -> None:
6162
default=None,
6263
)
6364

65+
def _validate_azure_auth_type(self) -> None:
66+
"""Validate the Azure authentication type.
67+
68+
azure_auth_type is required when using Azure OpenAI
69+
and explicitly defines the authentication type to use.
70+
71+
Raises
72+
------
73+
AzureAuthTypeMissingError
74+
If the Azure authentication type is missing when required.
75+
"""
76+
if (
77+
self.type == LLMType.AzureOpenAIChat
78+
or self.type == LLMType.AzureOpenAIEmbedding
79+
) and self.azure_auth_type is None:
80+
raise AzureAuthTypeMissingError(self.type.value)
81+
6482
type: LLMType = Field(description="The type of LLM model to use.")
6583
model: str = Field(description="The LLM model to use.")
6684
encoding_model: str = Field(description="The encoding model to use", default="")
@@ -220,13 +238,16 @@ def _validate_azure_settings(self) -> None:
220238
221239
Raises
222240
------
241+
AzureAuthTypeMissingError
242+
If the Azure authentication type is missing and is required.
223243
AzureApiBaseMissingError
224244
If the API base is missing and is required.
225245
AzureApiVersionMissingError
226246
If the API version is missing and is required.
227247
AzureDeploymentNameMissingError
228248
If the deployment name is missing and is required.
229249
"""
250+
self._validate_azure_auth_type()
230251
self._validate_api_base()
231252
self._validate_api_version()
232253
self._validate_deployment_name()

graphrag/query/llm/get_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
77

8-
from graphrag.config.enums import LLMType
8+
from graphrag.config.enums import AzureAuthType, LLMType
99
from graphrag.config.models.graph_rag_config import GraphRagConfig
1010
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
1111
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
@@ -31,7 +31,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
3131
api_key=default_llm_settings.api_key,
3232
azure_ad_token_provider=(
3333
get_bearer_token_provider(DefaultAzureCredential(), audience)
34-
if is_azure_client and not default_llm_settings.api_key
34+
if is_azure_client
35+
and default_llm_settings.azure_auth_type == AzureAuthType.ManagedIdentity
3536
else None
3637
),
3738
api_base=default_llm_settings.api_base,
@@ -65,7 +66,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
6566
api_key=embeddings_llm_settings.api_key,
6667
azure_ad_token_provider=(
6768
get_bearer_token_provider(DefaultAzureCredential(), audience)
68-
if is_azure_client and not embeddings_llm_settings.api_key
69+
if is_azure_client
70+
and embeddings_llm_settings.azure_auth_type == AzureAuthType.ManagedIdentity
6971
else None
7072
),
7173
api_base=embeddings_llm_settings.api_base,

0 commit comments

Comments
 (0)