Skip to content

Commit 94bd2bb

Browse files
authored
Require explicit azure auth settings when using AOI. (#1665)
* Require explicit azure auth settings when using AOI. - Must set LanguageModel.azure_auth_type to either "api_key" or "managed_identity" when using AOI. * Fix smoke tests * Use general auth_type property instead of azure_auth_type * Remove unused error type * Update validation * Update validation comment
1 parent d31750f commit 94bd2bb

File tree

10 files changed

+68
-24
lines changed

10 files changed

+68
-24
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Require explicit azure auth settings when using AOI."
4+
}

graphrag/config/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from graphrag.config.enums import (
99
AsyncType,
10+
AuthType,
1011
CacheType,
1112
ChunkStrategyType,
1213
InputFileType,
@@ -24,6 +25,7 @@
2425
ASYNC_MODE = AsyncType.Threaded
2526
ENCODING_MODEL = "cl100k_base"
2627
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
28+
AUTH_TYPE = AuthType.APIKey
2729
#
2830
# LLM Parameters
2931
#

graphrag/config/enums.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ def __repr__(self):
117117
return f'"{self.value}"'
118118

119119

120-
class AzureAuthType(str, Enum):
121-
"""AzureAuthType enum class definition."""
120+
class AuthType(str, Enum):
121+
"""AuthType enum class definition."""
122122

123123
APIKey = "api_key"
124-
ManagedIdentity = "managed_identity"
124+
AzureManagedIdentity = "azure_managed_identity"
125125

126126

127127
class AsyncType(str, Enum):

graphrag/config/init_content.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
{defs.DEFAULT_CHAT_MODEL_ID}:
1717
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
1818
type: {defs.LLM_TYPE.value} # or azure_openai_chat
19+
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
1920
model: {defs.LLM_MODEL}
2021
model_supports_json: true # recommended if this is available for your model.
2122
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
@@ -29,6 +30,7 @@
2930
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
3031
api_key: ${{GRAPHRAG_API_KEY}}
3132
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
33+
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
3234
model: {defs.EMBEDDING_MODEL}
3335
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
3436
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}

graphrag/config/models/language_model_config.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import BaseModel, Field, model_validator
88

99
import graphrag.config.defaults as defs
10-
from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
10+
from graphrag.config.enums import AsyncType, AuthType, LLMType
1111
from graphrag.config.errors import (
1212
ApiKeyMissingError,
1313
AzureApiBaseMissingError,
@@ -40,27 +40,42 @@ def _validate_api_key(self) -> None:
4040
ApiKeyMissingError
4141
If the API key is missing and is required.
4242
"""
43-
if (
44-
self.type == LLMType.OpenAIEmbedding
45-
or self.type == LLMType.OpenAIChat
46-
or self.azure_auth_type == AzureAuthType.APIKey
47-
) and (self.api_key is None or self.api_key.strip() == ""):
43+
if self.auth_type == AuthType.APIKey and (
44+
self.api_key is None or self.api_key.strip() == ""
45+
):
4846
raise ApiKeyMissingError(
4947
self.type.value,
50-
self.azure_auth_type.value if self.azure_auth_type else None,
48+
self.auth_type.value,
5149
)
5250

53-
if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
51+
if (self.auth_type == AuthType.AzureManagedIdentity) and (
5452
self.api_key is not None and self.api_key.strip() != ""
5553
):
5654
msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
5755
raise ConflictingSettingsError(msg)
5856

59-
azure_auth_type: AzureAuthType | None = Field(
60-
description="The Azure authentication type to use when using AOI.",
61-
default=None,
57+
auth_type: AuthType = Field(
58+
description="The authentication type.",
59+
default=defs.AUTH_TYPE,
6260
)
6361

62+
def _validate_auth_type(self) -> None:
63+
"""Validate the authentication type.
64+
65+
auth_type must be api_key when using OpenAI and
66+
can be either api_key or azure_managed_identity when using AOI.
67+
68+
Raises
69+
------
70+
ConflictingSettingsError
71+
If the Azure authentication type conflicts with the model being used.
72+
"""
73+
if self.auth_type == AuthType.AzureManagedIdentity and (
74+
self.type == LLMType.OpenAIChat or self.type == LLMType.OpenAIEmbedding
75+
):
76+
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
77+
raise ConflictingSettingsError(msg)
78+
6479
type: LLMType = Field(description="The type of LLM model to use.")
6580
model: str = Field(description="The LLM model to use.")
6681
encoding_model: str = Field(description="The encoding model to use", default="")
@@ -233,6 +248,7 @@ def _validate_azure_settings(self) -> None:
233248

234249
@model_validator(mode="after")
235250
def _validate_model(self):
251+
self._validate_auth_type()
236252
self._validate_api_key()
237253
self._validate_azure_settings()
238254
self._validate_encoding_model()

graphrag/query/llm/get_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
77

8-
from graphrag.config.enums import LLMType
8+
from graphrag.config.enums import AuthType, LLMType
99
from graphrag.config.models.graph_rag_config import GraphRagConfig
1010
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
1111
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
@@ -31,7 +31,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
3131
api_key=default_llm_settings.api_key,
3232
azure_ad_token_provider=(
3333
get_bearer_token_provider(DefaultAzureCredential(), audience)
34-
if is_azure_client and not default_llm_settings.api_key
34+
if is_azure_client
35+
and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
3536
else None
3637
),
3738
api_base=default_llm_settings.api_base,
@@ -65,7 +66,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
6566
api_key=embeddings_llm_settings.api_key,
6667
azure_ad_token_provider=(
6768
get_bearer_token_provider(DefaultAzureCredential(), audience)
68-
if is_azure_client and not embeddings_llm_settings.api_key
69+
if is_azure_client
70+
and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
6971
else None
7072
),
7173
api_base=embeddings_llm_settings.api_base,

tests/fixtures/min-csv/settings.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
models:
22
default_chat_model:
3+
azure_auth_type: api_key
34
type: ${GRAPHRAG_LLM_TYPE}
45
api_key: ${GRAPHRAG_API_KEY}
56
api_base: ${GRAPHRAG_API_BASE}
@@ -13,6 +14,7 @@ models:
1314
parallelization_stagger: 0.3
1415
async_mode: threaded
1516
default_embedding_model:
17+
azure_auth_type: api_key
1618
type: ${GRAPHRAG_EMBEDDING_TYPE}
1719
api_key: ${GRAPHRAG_API_KEY}
1820
api_base: ${GRAPHRAG_API_BASE}

tests/fixtures/text/settings.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
models:
22
default_chat_model:
3+
azure_auth_type: api_key
34
type: ${GRAPHRAG_LLM_TYPE}
45
api_key: ${GRAPHRAG_API_KEY}
56
api_base: ${GRAPHRAG_API_BASE}
@@ -13,6 +14,7 @@ models:
1314
parallelization_stagger: 0.3
1415
async_mode: threaded
1516
default_embedding_model:
17+
azure_auth_type: api_key
1618
type: ${GRAPHRAG_EMBEDDING_TYPE}
1719
api_key: ${GRAPHRAG_API_KEY}
1820
api_base: ${GRAPHRAG_API_BASE}

tests/unit/config/test_config.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import graphrag.config.defaults as defs
1212
from graphrag.config.create_graphrag_config import create_graphrag_config
13-
from graphrag.config.enums import AzureAuthType, LLMType
13+
from graphrag.config.enums import AuthType, LLMType
1414
from graphrag.config.load_config import load_config
1515
from tests.unit.config.utils import (
1616
DEFAULT_EMBEDDING_MODEL_CONFIG,
@@ -46,7 +46,7 @@ def test_missing_azure_api_key() -> None:
4646
model_config_missing_api_key = {
4747
defs.DEFAULT_CHAT_MODEL_ID: {
4848
"type": LLMType.AzureOpenAIChat,
49-
"azure_auth_type": AzureAuthType.APIKey,
49+
"auth_type": AuthType.APIKey,
5050
"model": defs.LLM_MODEL,
5151
"api_base": "some_api_base",
5252
"api_version": "some_api_version",
@@ -59,17 +59,31 @@ def test_missing_azure_api_key() -> None:
5959
create_graphrag_config({"models": model_config_missing_api_key})
6060

6161
# API Key not required for managed identity
62-
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
63-
AzureAuthType.ManagedIdentity
62+
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["auth_type"] = (
63+
AuthType.AzureManagedIdentity
6464
)
6565
create_graphrag_config({"models": model_config_missing_api_key})
6666

6767

68+
def test_conflicting_auth_type() -> None:
69+
model_config_invalid_auth_type = {
70+
defs.DEFAULT_CHAT_MODEL_ID: {
71+
"auth_type": AuthType.AzureManagedIdentity,
72+
"type": LLMType.OpenAIChat,
73+
"model": defs.LLM_MODEL,
74+
},
75+
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
76+
}
77+
78+
with pytest.raises(ValidationError):
79+
create_graphrag_config({"models": model_config_invalid_auth_type})
80+
81+
6882
def test_conflicting_azure_api_key() -> None:
6983
model_config_conflicting_api_key = {
7084
defs.DEFAULT_CHAT_MODEL_ID: {
7185
"type": LLMType.AzureOpenAIChat,
72-
"azure_auth_type": AzureAuthType.ManagedIdentity,
86+
"auth_type": AuthType.AzureManagedIdentity,
7387
"model": defs.LLM_MODEL,
7488
"api_base": "some_api_base",
7589
"api_version": "some_api_version",
@@ -85,7 +99,7 @@ def test_conflicting_azure_api_key() -> None:
8599

86100
base_azure_model_config = {
87101
"type": LLMType.AzureOpenAIChat,
88-
"azure_auth_type": AzureAuthType.ManagedIdentity,
102+
"auth_type": AuthType.AzureManagedIdentity,
89103
"model": defs.LLM_MODEL,
90104
"api_base": "some_api_base",
91105
"api_version": "some_api_version",

tests/unit/config/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def assert_language_model_configs(
245245
actual: LanguageModelConfig, expected: LanguageModelConfig
246246
) -> None:
247247
assert actual.api_key == expected.api_key
248-
assert actual.azure_auth_type == expected.azure_auth_type
248+
assert actual.auth_type == expected.auth_type
249249
assert actual.type == expected.type
250250
assert actual.model == expected.model
251251
assert actual.encoding_model == expected.encoding_model

0 commit comments

Comments
 (0)