Skip to content

Commit 9e5e5dd

Browse files
Copilotks6088ts
andcommitted
Implement authentication optimization for AzureOpenAiWrapper
Co-authored-by: ks6088ts <[email protected]>
1 parent 152d337 commit 9e5e5dd

File tree

2 files changed

+315
-46
lines changed

2 files changed

+315
-46
lines changed

template_langgraph/llms/azure_openais.py

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from functools import lru_cache
2+
import threading
3+
from typing import Optional
24

35
from azure.identity import DefaultAzureCredential
46
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
@@ -31,57 +33,109 @@ def get_azure_openai_settings() -> Settings:
3133

3234

3335
class AzureOpenAiWrapper:
36+
# Class-level variables for singleton-like behavior
37+
_credentials: dict = {}
38+
_tokens: dict = {}
39+
_token_lock = threading.Lock()
40+
3441
def __init__(self, settings: Settings = None):
3542
if settings is None:
3643
settings = get_azure_openai_settings()
44+
45+
self.settings = settings
46+
self._chat_model: Optional[AzureChatOpenAI] = None
47+
self._reasoning_model: Optional[AzureChatOpenAI] = None
48+
self._embedding_model: Optional[AzureOpenAIEmbeddings] = None
3749

38-
if settings.azure_openai_use_microsoft_entra_id.lower() == "true":
39-
logger.info("Using Microsoft Entra ID for authentication")
40-
credential = DefaultAzureCredential()
41-
token = credential.get_token("https://cognitiveservices.azure.com/.default").token
50+
def _get_auth_key(self) -> str:
51+
"""Generate a key for authentication caching based on settings."""
52+
return f"{self.settings.azure_openai_endpoint}_{self.settings.azure_openai_use_microsoft_entra_id}"
4253

43-
self.chat_model = AzureChatOpenAI(
44-
azure_ad_token=token,
45-
azure_endpoint=settings.azure_openai_endpoint,
46-
api_version=settings.azure_openai_api_version,
47-
azure_deployment=settings.azure_openai_model_chat,
48-
streaming=True,
49-
)
50-
self.reasoning_model = AzureChatOpenAI(
51-
azure_ad_token=token,
52-
azure_endpoint=settings.azure_openai_endpoint,
53-
api_version=settings.azure_openai_api_version,
54-
azure_deployment=settings.azure_openai_model_reasoning,
55-
streaming=True,
56-
)
57-
self.embedding_model = AzureOpenAIEmbeddings(
58-
azure_ad_token=token,
59-
azure_endpoint=settings.azure_openai_endpoint,
60-
api_version=settings.azure_openai_api_version,
61-
azure_deployment=settings.azure_openai_model_embedding,
62-
)
63-
else:
64-
logger.info("Using API key for authentication")
65-
self.chat_model = AzureChatOpenAI(
66-
api_key=settings.azure_openai_api_key,
67-
azure_endpoint=settings.azure_openai_endpoint,
68-
api_version=settings.azure_openai_api_version,
69-
azure_deployment=settings.azure_openai_model_chat,
70-
streaming=True,
71-
)
72-
self.reasoning_model = AzureChatOpenAI(
73-
api_key=settings.azure_openai_api_key,
74-
azure_endpoint=settings.azure_openai_endpoint,
75-
api_version=settings.azure_openai_api_version,
76-
azure_deployment=settings.azure_openai_model_reasoning,
77-
streaming=True,
78-
)
79-
self.embedding_model = AzureOpenAIEmbeddings(
80-
api_key=settings.azure_openai_api_key,
81-
azure_endpoint=settings.azure_openai_endpoint,
82-
api_version=settings.azure_openai_api_version,
83-
azure_deployment=settings.azure_openai_model_embedding,
84-
)
54+
def _get_auth_token(self) -> Optional[str]:
55+
"""Get authentication token with lazy initialization and caching."""
56+
if self.settings.azure_openai_use_microsoft_entra_id.lower() != "true":
57+
return None
58+
59+
auth_key = self._get_auth_key()
60+
61+
with self._token_lock:
62+
if auth_key not in self._credentials:
63+
logger.info("Initializing Microsoft Entra ID authentication")
64+
self._credentials[auth_key] = DefaultAzureCredential()
65+
66+
if auth_key not in self._tokens:
67+
logger.info("Getting authentication token")
68+
self._tokens[auth_key] = self._credentials[auth_key].get_token("https://cognitiveservices.azure.com/.default").token
69+
70+
return self._tokens[auth_key]
71+
72+
@property
73+
def chat_model(self) -> AzureChatOpenAI:
74+
"""Lazily initialize and return chat model."""
75+
if self._chat_model is None:
76+
if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
77+
token = self._get_auth_token()
78+
self._chat_model = AzureChatOpenAI(
79+
azure_ad_token=token,
80+
azure_endpoint=self.settings.azure_openai_endpoint,
81+
api_version=self.settings.azure_openai_api_version,
82+
azure_deployment=self.settings.azure_openai_model_chat,
83+
streaming=True,
84+
)
85+
else:
86+
logger.info("Using API key for authentication")
87+
self._chat_model = AzureChatOpenAI(
88+
api_key=self.settings.azure_openai_api_key,
89+
azure_endpoint=self.settings.azure_openai_endpoint,
90+
api_version=self.settings.azure_openai_api_version,
91+
azure_deployment=self.settings.azure_openai_model_chat,
92+
streaming=True,
93+
)
94+
return self._chat_model
95+
96+
@property
97+
def reasoning_model(self) -> AzureChatOpenAI:
98+
"""Lazily initialize and return reasoning model."""
99+
if self._reasoning_model is None:
100+
if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
101+
token = self._get_auth_token()
102+
self._reasoning_model = AzureChatOpenAI(
103+
azure_ad_token=token,
104+
azure_endpoint=self.settings.azure_openai_endpoint,
105+
api_version=self.settings.azure_openai_api_version,
106+
azure_deployment=self.settings.azure_openai_model_reasoning,
107+
streaming=True,
108+
)
109+
else:
110+
self._reasoning_model = AzureChatOpenAI(
111+
api_key=self.settings.azure_openai_api_key,
112+
azure_endpoint=self.settings.azure_openai_endpoint,
113+
api_version=self.settings.azure_openai_api_version,
114+
azure_deployment=self.settings.azure_openai_model_reasoning,
115+
streaming=True,
116+
)
117+
return self._reasoning_model
118+
119+
@property
120+
def embedding_model(self) -> AzureOpenAIEmbeddings:
121+
"""Lazily initialize and return embedding model."""
122+
if self._embedding_model is None:
123+
if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
124+
token = self._get_auth_token()
125+
self._embedding_model = AzureOpenAIEmbeddings(
126+
azure_ad_token=token,
127+
azure_endpoint=self.settings.azure_openai_endpoint,
128+
api_version=self.settings.azure_openai_api_version,
129+
azure_deployment=self.settings.azure_openai_model_embedding,
130+
)
131+
else:
132+
self._embedding_model = AzureOpenAIEmbeddings(
133+
api_key=self.settings.azure_openai_api_key,
134+
azure_endpoint=self.settings.azure_openai_endpoint,
135+
api_version=self.settings.azure_openai_api_version,
136+
azure_deployment=self.settings.azure_openai_model_embedding,
137+
)
138+
return self._embedding_model
85139

86140
def create_embedding(self, text: str):
87141
"""Create an embedding for the given text."""

tests/test_azure_openais.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import logging
2+
import threading
3+
import time
4+
from unittest.mock import Mock, patch
5+
import pytest
6+
7+
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper, Settings
8+
9+
10+
class TestAzureOpenAiWrapper:
11+
"""Test cases for AzureOpenAiWrapper authentication optimization."""
12+
13+
def setup_method(self):
14+
"""Reset class-level variables before each test."""
15+
AzureOpenAiWrapper._credentials.clear()
16+
AzureOpenAiWrapper._tokens.clear()
17+
18+
def test_lazy_initialization_api_key(self, caplog):
19+
"""Test that API key authentication uses lazy initialization."""
20+
settings = Settings(
21+
azure_openai_use_microsoft_entra_id="false",
22+
azure_openai_api_key="dummy_key",
23+
azure_openai_endpoint="https://dummy.openai.azure.com/",
24+
)
25+
26+
with caplog.at_level(logging.INFO):
27+
# Creating instances should not trigger authentication
28+
wrapper1 = AzureOpenAiWrapper(settings)
29+
wrapper2 = AzureOpenAiWrapper(settings)
30+
31+
# No authentication logs yet
32+
assert "Using API key for authentication" not in caplog.text
33+
34+
# Accessing models should trigger authentication
35+
try:
36+
_ = wrapper1.chat_model
37+
except Exception:
38+
pass # Expected due to dummy credentials
39+
40+
# Should see authentication log only once per model access
41+
assert caplog.text.count("Using API key for authentication") == 1
42+
43+
# Second access should not trigger additional authentication
44+
try:
45+
_ = wrapper1.reasoning_model
46+
except Exception:
47+
pass
48+
49+
# Should still be only one authentication log per model type
50+
assert caplog.text.count("Using API key for authentication") >= 1
51+
52+
@patch('template_langgraph.llms.azure_openais.DefaultAzureCredential')
53+
def test_singleton_credential_entra_id(self, mock_credential_class, caplog):
54+
"""Test that Microsoft Entra ID credentials are reused across instances."""
55+
# Mock the credential and token
56+
mock_credential = Mock()
57+
mock_token_obj = Mock()
58+
mock_token_obj.token = "mock_token_123"
59+
mock_credential.get_token.return_value = mock_token_obj
60+
mock_credential_class.return_value = mock_credential
61+
62+
settings = Settings(
63+
azure_openai_use_microsoft_entra_id="true",
64+
azure_openai_endpoint="https://dummy.openai.azure.com/",
65+
)
66+
67+
with caplog.at_level(logging.INFO):
68+
# Create multiple instances
69+
wrapper1 = AzureOpenAiWrapper(settings)
70+
wrapper2 = AzureOpenAiWrapper(settings)
71+
72+
# Access models to trigger authentication
73+
try:
74+
_ = wrapper1.chat_model
75+
_ = wrapper2.chat_model
76+
except Exception:
77+
pass # Expected due to mocking
78+
79+
# Credential should be initialized only once
80+
assert mock_credential_class.call_count == 1
81+
# Token should be requested only once
82+
assert mock_credential.get_token.call_count == 1
83+
84+
# Should see initialization logs only once
85+
assert caplog.text.count("Initializing Microsoft Entra ID authentication") == 1
86+
assert caplog.text.count("Getting authentication token") == 1
87+
88+
@patch('template_langgraph.llms.azure_openais.DefaultAzureCredential')
89+
def test_thread_safety(self, mock_credential_class):
90+
"""Test that authentication is thread-safe."""
91+
mock_credential = Mock()
92+
mock_token_obj = Mock()
93+
mock_token_obj.token = "mock_token_123"
94+
mock_credential.get_token.return_value = mock_token_obj
95+
mock_credential_class.return_value = mock_credential
96+
97+
settings = Settings(
98+
azure_openai_use_microsoft_entra_id="true",
99+
azure_openai_endpoint="https://dummy.openai.azure.com/",
100+
)
101+
102+
results = []
103+
errors = []
104+
105+
def worker():
106+
try:
107+
wrapper = AzureOpenAiWrapper(settings)
108+
token = wrapper._get_auth_token()
109+
results.append(token)
110+
except Exception as e:
111+
errors.append(e)
112+
113+
# Create multiple threads that try to authenticate simultaneously
114+
threads = []
115+
for _ in range(10):
116+
thread = threading.Thread(target=worker)
117+
threads.append(thread)
118+
119+
# Start all threads
120+
for thread in threads:
121+
thread.start()
122+
123+
# Wait for all threads to complete
124+
for thread in threads:
125+
thread.join()
126+
127+
# Check results
128+
assert len(errors) == 0, f"Errors occurred: {errors}"
129+
assert len(results) == 10
130+
assert all(token == "mock_token_123" for token in results)
131+
132+
# Credential should be initialized only once despite multiple threads
133+
assert mock_credential_class.call_count == 1
134+
assert mock_credential.get_token.call_count == 1
135+
136+
def test_different_settings_per_instance(self):
137+
"""Test that different instances can have different settings."""
138+
settings1 = Settings(
139+
azure_openai_use_microsoft_entra_id="false",
140+
azure_openai_api_key="key1",
141+
azure_openai_endpoint="https://endpoint1.openai.azure.com/",
142+
)
143+
144+
settings2 = Settings(
145+
azure_openai_use_microsoft_entra_id="false",
146+
azure_openai_api_key="key2",
147+
azure_openai_endpoint="https://endpoint2.openai.azure.com/",
148+
)
149+
150+
wrapper1 = AzureOpenAiWrapper(settings1)
151+
wrapper2 = AzureOpenAiWrapper(settings2)
152+
153+
# Each instance should maintain its own settings
154+
assert wrapper1.settings.azure_openai_api_key == "key1"
155+
assert wrapper2.settings.azure_openai_api_key == "key2"
156+
assert wrapper1.settings.azure_openai_endpoint == "https://endpoint1.openai.azure.com/"
157+
assert wrapper2.settings.azure_openai_endpoint == "https://endpoint2.openai.azure.com/"
158+
159+
def test_create_embedding_method_compatibility(self):
160+
"""Test that the create_embedding method still works."""
161+
settings = Settings(
162+
azure_openai_use_microsoft_entra_id="false",
163+
azure_openai_api_key="dummy_key",
164+
azure_openai_endpoint="https://dummy.openai.azure.com/",
165+
)
166+
167+
wrapper = AzureOpenAiWrapper(settings)
168+
169+
# This should not raise an error about missing methods
170+
# (though it will fail due to dummy credentials)
171+
try:
172+
wrapper.create_embedding("test text")
173+
except Exception:
174+
pass # Expected due to dummy credentials
175+
176+
# Verify the method exists and is callable
177+
assert hasattr(wrapper, 'create_embedding')
178+
assert callable(getattr(wrapper, 'create_embedding'))
179+
180+
@patch('template_langgraph.llms.azure_openais.DefaultAzureCredential')
181+
def test_mixed_authentication_methods(self, mock_credential_class, caplog):
182+
"""Test using both authentication methods in different instances."""
183+
mock_credential = Mock()
184+
mock_token_obj = Mock()
185+
mock_token_obj.token = "mock_token_123"
186+
mock_credential.get_token.return_value = mock_token_obj
187+
mock_credential_class.return_value = mock_credential
188+
189+
# API key settings
190+
api_settings = Settings(
191+
azure_openai_use_microsoft_entra_id="false",
192+
azure_openai_api_key="dummy_key",
193+
azure_openai_endpoint="https://dummy.openai.azure.com/",
194+
)
195+
196+
# Entra ID settings
197+
entra_settings = Settings(
198+
azure_openai_use_microsoft_entra_id="true",
199+
azure_openai_endpoint="https://dummy.openai.azure.com/",
200+
)
201+
202+
with caplog.at_level(logging.INFO):
203+
wrapper_api = AzureOpenAiWrapper(api_settings)
204+
wrapper_entra = AzureOpenAiWrapper(entra_settings)
205+
206+
# Access models to trigger different authentication paths
207+
try:
208+
_ = wrapper_api.chat_model
209+
_ = wrapper_entra.chat_model
210+
except Exception:
211+
pass # Expected due to dummy/mock credentials
212+
213+
# Should see both authentication methods being used
214+
assert "Using API key for authentication" in caplog.text
215+
assert "Initializing Microsoft Entra ID authentication" in caplog.text

0 commit comments

Comments
 (0)