
Commit 8fbcb94

Merge pull request #115 from ks6088ts-labs/copilot/fix-114
Optimize AzureOpenAiWrapper authentication with singleton pattern and lazy initialization
2 parents 64e04d1 + 8d8b5c4 commit 8fbcb94

File tree: 2 files changed, 314 additions & 47 deletions

template_langgraph/llms/azure_openais.py

Lines changed: 102 additions & 47 deletions
@@ -1,3 +1,4 @@
+import threading
 from functools import lru_cache
 
 from azure.identity import DefaultAzureCredential
@@ -31,57 +32,111 @@ def get_azure_openai_settings() -> Settings:
 
 
 class AzureOpenAiWrapper:
+    # Class-level variables for singleton-like behavior
+    _credentials: dict = {}
+    _tokens: dict = {}
+    _token_lock = threading.Lock()
+
     def __init__(self, settings: Settings = None):
         if settings is None:
             settings = get_azure_openai_settings()
 
-        if settings.azure_openai_use_microsoft_entra_id.lower() == "true":
-            logger.info("Using Microsoft Entra ID for authentication")
-            credential = DefaultAzureCredential()
-            token = credential.get_token("https://cognitiveservices.azure.com/.default").token
-
-            self.chat_model = AzureChatOpenAI(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_chat,
-                streaming=True,
-            )
-            self.reasoning_model = AzureChatOpenAI(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_reasoning,
-                streaming=True,
-            )
-            self.embedding_model = AzureOpenAIEmbeddings(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_embedding,
-            )
-        else:
-            logger.info("Using API key for authentication")
-            self.chat_model = AzureChatOpenAI(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_chat,
-                streaming=True,
-            )
-            self.reasoning_model = AzureChatOpenAI(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_reasoning,
-                streaming=True,
-            )
-            self.embedding_model = AzureOpenAIEmbeddings(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_embedding,
-            )
+        self.settings = settings
+        self._chat_model: AzureChatOpenAI | None = None
+        self._reasoning_model: AzureChatOpenAI | None = None
+        self._embedding_model: AzureOpenAIEmbeddings | None = None
+
+    def _get_auth_key(self) -> str:
+        """Generate a key for authentication caching based on settings."""
+        return f"{self.settings.azure_openai_endpoint}_{self.settings.azure_openai_use_microsoft_entra_id}"
+
+    def _get_auth_token(self) -> str | None:
+        """Get authentication token with lazy initialization and caching."""
+        if self.settings.azure_openai_use_microsoft_entra_id.lower() != "true":
+            return None
+
+        auth_key = self._get_auth_key()
+
+        with self._token_lock:
+            if auth_key not in self._credentials:
+                logger.info("Initializing Microsoft Entra ID authentication")
+                self._credentials[auth_key] = DefaultAzureCredential()
+
+            if auth_key not in self._tokens:
+                logger.info("Getting authentication token")
+                self._tokens[auth_key] = (
+                    self._credentials[auth_key].get_token("https://cognitiveservices.azure.com/.default").token
+                )
+
+            return self._tokens[auth_key]
+
+    @property
+    def chat_model(self) -> AzureChatOpenAI:
+        """Lazily initialize and return chat model."""
+        if self._chat_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._chat_model = AzureChatOpenAI(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_chat,
+                    streaming=True,
+                )
+            else:
+                logger.info("Using API key for authentication")
+                self._chat_model = AzureChatOpenAI(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_chat,
+                    streaming=True,
+                )
+        return self._chat_model
+
+    @property
+    def reasoning_model(self) -> AzureChatOpenAI:
+        """Lazily initialize and return reasoning model."""
+        if self._reasoning_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._reasoning_model = AzureChatOpenAI(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_reasoning,
+                    streaming=True,
+                )
+            else:
+                self._reasoning_model = AzureChatOpenAI(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_reasoning,
+                    streaming=True,
+                )
+        return self._reasoning_model
+
+    @property
+    def embedding_model(self) -> AzureOpenAIEmbeddings:
+        """Lazily initialize and return embedding model."""
+        if self._embedding_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._embedding_model = AzureOpenAIEmbeddings(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_embedding,
+                )
+            else:
+                self._embedding_model = AzureOpenAIEmbeddings(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_embedding,
+                )
+        return self._embedding_model
 
     def create_embedding(self, text: str):
         """Create an embedding for the given text."""

tests/test_azure_openais.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
+import logging
+import threading
+from unittest.mock import Mock, patch
+
+from template_langgraph.llms.azure_openais import AzureOpenAiWrapper, Settings
+
+
+class TestAzureOpenAiWrapper:
+    """Test cases for AzureOpenAiWrapper authentication optimization."""
+
+    def setup_method(self):
+        """Reset class-level variables before each test."""
+        AzureOpenAiWrapper._credentials.clear()
+        AzureOpenAiWrapper._tokens.clear()
+
+    def test_lazy_initialization_api_key(self, caplog):
+        """Test that API key authentication uses lazy initialization."""
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            # Creating instances should not trigger authentication
+            wrapper1 = AzureOpenAiWrapper(settings)
+
+            # No authentication logs yet
+            assert "Using API key for authentication" not in caplog.text
+
+            # Accessing models should trigger authentication
+            try:
+                _ = wrapper1.chat_model
+            except Exception:
+                pass  # Expected due to dummy credentials
+
+            # Should see authentication log only once per model access
+            assert caplog.text.count("Using API key for authentication") == 1
+
+            # Second access should not trigger additional authentication
+            try:
+                _ = wrapper1.reasoning_model
+            except Exception:
+                pass
+
+            # Should still be only one authentication log per model type
+            assert caplog.text.count("Using API key for authentication") >= 1
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_singleton_credential_entra_id(self, mock_credential_class, caplog):
+        """Test that Microsoft Entra ID credentials are reused across instances."""
+        # Mock the credential and token
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            # Create multiple instances
+            wrapper1 = AzureOpenAiWrapper(settings)
+            wrapper2 = AzureOpenAiWrapper(settings)
+
+            # Access models to trigger authentication
+            try:
+                _ = wrapper1.chat_model
+                _ = wrapper2.chat_model
+            except Exception:
+                pass  # Expected due to mocking
+
+        # Credential should be initialized only once
+        assert mock_credential_class.call_count == 1
+        # Token should be requested only once
+        assert mock_credential.get_token.call_count == 1
+
+        # Should see initialization logs only once
+        assert caplog.text.count("Initializing Microsoft Entra ID authentication") == 1
+        assert caplog.text.count("Getting authentication token") == 1
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_thread_safety(self, mock_credential_class):
+        """Test that authentication is thread-safe."""
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        results = []
+        errors = []
+
+        def worker():
+            try:
+                wrapper = AzureOpenAiWrapper(settings)
+                token = wrapper._get_auth_token()
+                results.append(token)
+            except Exception as e:
+                errors.append(e)
+
+        # Create multiple threads that try to authenticate simultaneously
+        threads = []
+        for _ in range(10):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Check results
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert len(results) == 10
+        assert all(token == "mock_token_123" for token in results)
+
+        # Credential should be initialized only once despite multiple threads
+        assert mock_credential_class.call_count == 1
+        assert mock_credential.get_token.call_count == 1
+
+    def test_different_settings_per_instance(self):
+        """Test that different instances can have different settings."""
+        settings1 = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="key1",
+            azure_openai_endpoint="https://endpoint1.openai.azure.com/",
+        )
+
+        settings2 = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="key2",
+            azure_openai_endpoint="https://endpoint2.openai.azure.com/",
+        )
+
+        wrapper1 = AzureOpenAiWrapper(settings1)
+        wrapper2 = AzureOpenAiWrapper(settings2)
+
+        # Each instance should maintain its own settings
+        assert wrapper1.settings.azure_openai_api_key == "key1"
+        assert wrapper2.settings.azure_openai_api_key == "key2"
+        assert wrapper1.settings.azure_openai_endpoint == "https://endpoint1.openai.azure.com/"
+        assert wrapper2.settings.azure_openai_endpoint == "https://endpoint2.openai.azure.com/"
+
+    def test_create_embedding_method_compatibility(self):
+        """Test that the create_embedding method still works."""
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        wrapper = AzureOpenAiWrapper(settings)
+
+        # This should not raise an error about missing methods
+        # (though it will fail due to dummy credentials)
+        try:
+            wrapper.create_embedding("test text")
+        except Exception:
+            pass  # Expected due to dummy credentials
+
+        # Verify the method exists and is callable
+        assert hasattr(wrapper, "create_embedding")
+        assert callable(getattr(wrapper, "create_embedding"))
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_mixed_authentication_methods(self, mock_credential_class, caplog):
+        """Test using both authentication methods in different instances."""
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        # API key settings
+        api_settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        # Entra ID settings
+        entra_settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            wrapper_api = AzureOpenAiWrapper(api_settings)
+            wrapper_entra = AzureOpenAiWrapper(entra_settings)
+
+            # Access models to trigger different authentication paths
+            try:
+                _ = wrapper_api.chat_model
+                _ = wrapper_entra.chat_model
+            except Exception:
+                pass  # Expected due to dummy/mock credentials
+
+        # Should see both authentication methods being used
+        assert "Using API key for authentication" in caplog.text
+        assert "Initializing Microsoft Entra ID authentication" in caplog.text
