1+ import logging
2+ import threading
3+ import time
4+ from unittest .mock import Mock , patch
5+ import pytest
6+
7+ from template_langgraph .llms .azure_openais import AzureOpenAiWrapper , Settings
8+
9+
10+ class TestAzureOpenAiWrapper :
11+ """Test cases for AzureOpenAiWrapper authentication optimization."""
12+
13+ def setup_method (self ):
14+ """Reset class-level variables before each test."""
15+ AzureOpenAiWrapper ._credentials .clear ()
16+ AzureOpenAiWrapper ._tokens .clear ()
17+
18+ def test_lazy_initialization_api_key (self , caplog ):
19+ """Test that API key authentication uses lazy initialization."""
20+ settings = Settings (
21+ azure_openai_use_microsoft_entra_id = "false" ,
22+ azure_openai_api_key = "dummy_key" ,
23+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
24+ )
25+
26+ with caplog .at_level (logging .INFO ):
27+ # Creating instances should not trigger authentication
28+ wrapper1 = AzureOpenAiWrapper (settings )
29+ wrapper2 = AzureOpenAiWrapper (settings )
30+
31+ # No authentication logs yet
32+ assert "Using API key for authentication" not in caplog .text
33+
34+ # Accessing models should trigger authentication
35+ try :
36+ _ = wrapper1 .chat_model
37+ except Exception :
38+ pass # Expected due to dummy credentials
39+
40+ # Should see authentication log only once per model access
41+ assert caplog .text .count ("Using API key for authentication" ) == 1
42+
43+ # Second access should not trigger additional authentication
44+ try :
45+ _ = wrapper1 .reasoning_model
46+ except Exception :
47+ pass
48+
49+ # Should still be only one authentication log per model type
50+ assert caplog .text .count ("Using API key for authentication" ) >= 1
51+
52+ @patch ('template_langgraph.llms.azure_openais.DefaultAzureCredential' )
53+ def test_singleton_credential_entra_id (self , mock_credential_class , caplog ):
54+ """Test that Microsoft Entra ID credentials are reused across instances."""
55+ # Mock the credential and token
56+ mock_credential = Mock ()
57+ mock_token_obj = Mock ()
58+ mock_token_obj .token = "mock_token_123"
59+ mock_credential .get_token .return_value = mock_token_obj
60+ mock_credential_class .return_value = mock_credential
61+
62+ settings = Settings (
63+ azure_openai_use_microsoft_entra_id = "true" ,
64+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
65+ )
66+
67+ with caplog .at_level (logging .INFO ):
68+ # Create multiple instances
69+ wrapper1 = AzureOpenAiWrapper (settings )
70+ wrapper2 = AzureOpenAiWrapper (settings )
71+
72+ # Access models to trigger authentication
73+ try :
74+ _ = wrapper1 .chat_model
75+ _ = wrapper2 .chat_model
76+ except Exception :
77+ pass # Expected due to mocking
78+
79+ # Credential should be initialized only once
80+ assert mock_credential_class .call_count == 1
81+ # Token should be requested only once
82+ assert mock_credential .get_token .call_count == 1
83+
84+ # Should see initialization logs only once
85+ assert caplog .text .count ("Initializing Microsoft Entra ID authentication" ) == 1
86+ assert caplog .text .count ("Getting authentication token" ) == 1
87+
88+ @patch ('template_langgraph.llms.azure_openais.DefaultAzureCredential' )
89+ def test_thread_safety (self , mock_credential_class ):
90+ """Test that authentication is thread-safe."""
91+ mock_credential = Mock ()
92+ mock_token_obj = Mock ()
93+ mock_token_obj .token = "mock_token_123"
94+ mock_credential .get_token .return_value = mock_token_obj
95+ mock_credential_class .return_value = mock_credential
96+
97+ settings = Settings (
98+ azure_openai_use_microsoft_entra_id = "true" ,
99+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
100+ )
101+
102+ results = []
103+ errors = []
104+
105+ def worker ():
106+ try :
107+ wrapper = AzureOpenAiWrapper (settings )
108+ token = wrapper ._get_auth_token ()
109+ results .append (token )
110+ except Exception as e :
111+ errors .append (e )
112+
113+ # Create multiple threads that try to authenticate simultaneously
114+ threads = []
115+ for _ in range (10 ):
116+ thread = threading .Thread (target = worker )
117+ threads .append (thread )
118+
119+ # Start all threads
120+ for thread in threads :
121+ thread .start ()
122+
123+ # Wait for all threads to complete
124+ for thread in threads :
125+ thread .join ()
126+
127+ # Check results
128+ assert len (errors ) == 0 , f"Errors occurred: { errors } "
129+ assert len (results ) == 10
130+ assert all (token == "mock_token_123" for token in results )
131+
132+ # Credential should be initialized only once despite multiple threads
133+ assert mock_credential_class .call_count == 1
134+ assert mock_credential .get_token .call_count == 1
135+
136+ def test_different_settings_per_instance (self ):
137+ """Test that different instances can have different settings."""
138+ settings1 = Settings (
139+ azure_openai_use_microsoft_entra_id = "false" ,
140+ azure_openai_api_key = "key1" ,
141+ azure_openai_endpoint = "https://endpoint1.openai.azure.com/" ,
142+ )
143+
144+ settings2 = Settings (
145+ azure_openai_use_microsoft_entra_id = "false" ,
146+ azure_openai_api_key = "key2" ,
147+ azure_openai_endpoint = "https://endpoint2.openai.azure.com/" ,
148+ )
149+
150+ wrapper1 = AzureOpenAiWrapper (settings1 )
151+ wrapper2 = AzureOpenAiWrapper (settings2 )
152+
153+ # Each instance should maintain its own settings
154+ assert wrapper1 .settings .azure_openai_api_key == "key1"
155+ assert wrapper2 .settings .azure_openai_api_key == "key2"
156+ assert wrapper1 .settings .azure_openai_endpoint == "https://endpoint1.openai.azure.com/"
157+ assert wrapper2 .settings .azure_openai_endpoint == "https://endpoint2.openai.azure.com/"
158+
159+ def test_create_embedding_method_compatibility (self ):
160+ """Test that the create_embedding method still works."""
161+ settings = Settings (
162+ azure_openai_use_microsoft_entra_id = "false" ,
163+ azure_openai_api_key = "dummy_key" ,
164+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
165+ )
166+
167+ wrapper = AzureOpenAiWrapper (settings )
168+
169+ # This should not raise an error about missing methods
170+ # (though it will fail due to dummy credentials)
171+ try :
172+ wrapper .create_embedding ("test text" )
173+ except Exception :
174+ pass # Expected due to dummy credentials
175+
176+ # Verify the method exists and is callable
177+ assert hasattr (wrapper , 'create_embedding' )
178+ assert callable (getattr (wrapper , 'create_embedding' ))
179+
180+ @patch ('template_langgraph.llms.azure_openais.DefaultAzureCredential' )
181+ def test_mixed_authentication_methods (self , mock_credential_class , caplog ):
182+ """Test using both authentication methods in different instances."""
183+ mock_credential = Mock ()
184+ mock_token_obj = Mock ()
185+ mock_token_obj .token = "mock_token_123"
186+ mock_credential .get_token .return_value = mock_token_obj
187+ mock_credential_class .return_value = mock_credential
188+
189+ # API key settings
190+ api_settings = Settings (
191+ azure_openai_use_microsoft_entra_id = "false" ,
192+ azure_openai_api_key = "dummy_key" ,
193+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
194+ )
195+
196+ # Entra ID settings
197+ entra_settings = Settings (
198+ azure_openai_use_microsoft_entra_id = "true" ,
199+ azure_openai_endpoint = "https://dummy.openai.azure.com/" ,
200+ )
201+
202+ with caplog .at_level (logging .INFO ):
203+ wrapper_api = AzureOpenAiWrapper (api_settings )
204+ wrapper_entra = AzureOpenAiWrapper (entra_settings )
205+
206+ # Access models to trigger different authentication paths
207+ try :
208+ _ = wrapper_api .chat_model
209+ _ = wrapper_entra .chat_model
210+ except Exception :
211+ pass # Expected due to dummy/mock credentials
212+
213+ # Should see both authentication methods being used
214+ assert "Using API key for authentication" in caplog .text
215+ assert "Initializing Microsoft Entra ID authentication" in caplog .text
0 commit comments