|
| 1 | +import logging |
| 2 | +import threading |
| 3 | +from unittest.mock import Mock, patch |
| 4 | + |
| 5 | +from template_langgraph.llms.azure_openais import AzureOpenAiWrapper, Settings |
| 6 | + |
| 7 | + |
| 8 | +class TestAzureOpenAiWrapper: |
| 9 | + """Test cases for AzureOpenAiWrapper authentication optimization.""" |
| 10 | + |
| 11 | + def setup_method(self): |
| 12 | + """Reset class-level variables before each test.""" |
| 13 | + AzureOpenAiWrapper._credentials.clear() |
| 14 | + AzureOpenAiWrapper._tokens.clear() |
| 15 | + |
| 16 | + def test_lazy_initialization_api_key(self, caplog): |
| 17 | + """Test that API key authentication uses lazy initialization.""" |
| 18 | + settings = Settings( |
| 19 | + azure_openai_use_microsoft_entra_id="false", |
| 20 | + azure_openai_api_key="dummy_key", |
| 21 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 22 | + ) |
| 23 | + |
| 24 | + with caplog.at_level(logging.INFO): |
| 25 | + # Creating instances should not trigger authentication |
| 26 | + wrapper1 = AzureOpenAiWrapper(settings) |
| 27 | + |
| 28 | + # No authentication logs yet |
| 29 | + assert "Using API key for authentication" not in caplog.text |
| 30 | + |
| 31 | + # Accessing models should trigger authentication |
| 32 | + try: |
| 33 | + _ = wrapper1.chat_model |
| 34 | + except Exception: |
| 35 | + pass # Expected due to dummy credentials |
| 36 | + |
| 37 | + # Should see authentication log only once per model access |
| 38 | + assert caplog.text.count("Using API key for authentication") == 1 |
| 39 | + |
| 40 | + # Second access should not trigger additional authentication |
| 41 | + try: |
| 42 | + _ = wrapper1.reasoning_model |
| 43 | + except Exception: |
| 44 | + pass |
| 45 | + |
| 46 | + # Should still be only one authentication log per model type |
| 47 | + assert caplog.text.count("Using API key for authentication") >= 1 |
| 48 | + |
| 49 | + @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential") |
| 50 | + def test_singleton_credential_entra_id(self, mock_credential_class, caplog): |
| 51 | + """Test that Microsoft Entra ID credentials are reused across instances.""" |
| 52 | + # Mock the credential and token |
| 53 | + mock_credential = Mock() |
| 54 | + mock_token_obj = Mock() |
| 55 | + mock_token_obj.token = "mock_token_123" |
| 56 | + mock_credential.get_token.return_value = mock_token_obj |
| 57 | + mock_credential_class.return_value = mock_credential |
| 58 | + |
| 59 | + settings = Settings( |
| 60 | + azure_openai_use_microsoft_entra_id="true", |
| 61 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 62 | + ) |
| 63 | + |
| 64 | + with caplog.at_level(logging.INFO): |
| 65 | + # Create multiple instances |
| 66 | + wrapper1 = AzureOpenAiWrapper(settings) |
| 67 | + wrapper2 = AzureOpenAiWrapper(settings) |
| 68 | + |
| 69 | + # Access models to trigger authentication |
| 70 | + try: |
| 71 | + _ = wrapper1.chat_model |
| 72 | + _ = wrapper2.chat_model |
| 73 | + except Exception: |
| 74 | + pass # Expected due to mocking |
| 75 | + |
| 76 | + # Credential should be initialized only once |
| 77 | + assert mock_credential_class.call_count == 1 |
| 78 | + # Token should be requested only once |
| 79 | + assert mock_credential.get_token.call_count == 1 |
| 80 | + |
| 81 | + # Should see initialization logs only once |
| 82 | + assert caplog.text.count("Initializing Microsoft Entra ID authentication") == 1 |
| 83 | + assert caplog.text.count("Getting authentication token") == 1 |
| 84 | + |
| 85 | + @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential") |
| 86 | + def test_thread_safety(self, mock_credential_class): |
| 87 | + """Test that authentication is thread-safe.""" |
| 88 | + mock_credential = Mock() |
| 89 | + mock_token_obj = Mock() |
| 90 | + mock_token_obj.token = "mock_token_123" |
| 91 | + mock_credential.get_token.return_value = mock_token_obj |
| 92 | + mock_credential_class.return_value = mock_credential |
| 93 | + |
| 94 | + settings = Settings( |
| 95 | + azure_openai_use_microsoft_entra_id="true", |
| 96 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 97 | + ) |
| 98 | + |
| 99 | + results = [] |
| 100 | + errors = [] |
| 101 | + |
| 102 | + def worker(): |
| 103 | + try: |
| 104 | + wrapper = AzureOpenAiWrapper(settings) |
| 105 | + token = wrapper._get_auth_token() |
| 106 | + results.append(token) |
| 107 | + except Exception as e: |
| 108 | + errors.append(e) |
| 109 | + |
| 110 | + # Create multiple threads that try to authenticate simultaneously |
| 111 | + threads = [] |
| 112 | + for _ in range(10): |
| 113 | + thread = threading.Thread(target=worker) |
| 114 | + threads.append(thread) |
| 115 | + |
| 116 | + # Start all threads |
| 117 | + for thread in threads: |
| 118 | + thread.start() |
| 119 | + |
| 120 | + # Wait for all threads to complete |
| 121 | + for thread in threads: |
| 122 | + thread.join() |
| 123 | + |
| 124 | + # Check results |
| 125 | + assert len(errors) == 0, f"Errors occurred: {errors}" |
| 126 | + assert len(results) == 10 |
| 127 | + assert all(token == "mock_token_123" for token in results) |
| 128 | + |
| 129 | + # Credential should be initialized only once despite multiple threads |
| 130 | + assert mock_credential_class.call_count == 1 |
| 131 | + assert mock_credential.get_token.call_count == 1 |
| 132 | + |
| 133 | + def test_different_settings_per_instance(self): |
| 134 | + """Test that different instances can have different settings.""" |
| 135 | + settings1 = Settings( |
| 136 | + azure_openai_use_microsoft_entra_id="false", |
| 137 | + azure_openai_api_key="key1", |
| 138 | + azure_openai_endpoint="https://endpoint1.openai.azure.com/", |
| 139 | + ) |
| 140 | + |
| 141 | + settings2 = Settings( |
| 142 | + azure_openai_use_microsoft_entra_id="false", |
| 143 | + azure_openai_api_key="key2", |
| 144 | + azure_openai_endpoint="https://endpoint2.openai.azure.com/", |
| 145 | + ) |
| 146 | + |
| 147 | + wrapper1 = AzureOpenAiWrapper(settings1) |
| 148 | + wrapper2 = AzureOpenAiWrapper(settings2) |
| 149 | + |
| 150 | + # Each instance should maintain its own settings |
| 151 | + assert wrapper1.settings.azure_openai_api_key == "key1" |
| 152 | + assert wrapper2.settings.azure_openai_api_key == "key2" |
| 153 | + assert wrapper1.settings.azure_openai_endpoint == "https://endpoint1.openai.azure.com/" |
| 154 | + assert wrapper2.settings.azure_openai_endpoint == "https://endpoint2.openai.azure.com/" |
| 155 | + |
| 156 | + def test_create_embedding_method_compatibility(self): |
| 157 | + """Test that the create_embedding method still works.""" |
| 158 | + settings = Settings( |
| 159 | + azure_openai_use_microsoft_entra_id="false", |
| 160 | + azure_openai_api_key="dummy_key", |
| 161 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 162 | + ) |
| 163 | + |
| 164 | + wrapper = AzureOpenAiWrapper(settings) |
| 165 | + |
| 166 | + # This should not raise an error about missing methods |
| 167 | + # (though it will fail due to dummy credentials) |
| 168 | + try: |
| 169 | + wrapper.create_embedding("test text") |
| 170 | + except Exception: |
| 171 | + pass # Expected due to dummy credentials |
| 172 | + |
| 173 | + # Verify the method exists and is callable |
| 174 | + assert hasattr(wrapper, "create_embedding") |
| 175 | + assert callable(getattr(wrapper, "create_embedding")) |
| 176 | + |
| 177 | + @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential") |
| 178 | + def test_mixed_authentication_methods(self, mock_credential_class, caplog): |
| 179 | + """Test using both authentication methods in different instances.""" |
| 180 | + mock_credential = Mock() |
| 181 | + mock_token_obj = Mock() |
| 182 | + mock_token_obj.token = "mock_token_123" |
| 183 | + mock_credential.get_token.return_value = mock_token_obj |
| 184 | + mock_credential_class.return_value = mock_credential |
| 185 | + |
| 186 | + # API key settings |
| 187 | + api_settings = Settings( |
| 188 | + azure_openai_use_microsoft_entra_id="false", |
| 189 | + azure_openai_api_key="dummy_key", |
| 190 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 191 | + ) |
| 192 | + |
| 193 | + # Entra ID settings |
| 194 | + entra_settings = Settings( |
| 195 | + azure_openai_use_microsoft_entra_id="true", |
| 196 | + azure_openai_endpoint="https://dummy.openai.azure.com/", |
| 197 | + ) |
| 198 | + |
| 199 | + with caplog.at_level(logging.INFO): |
| 200 | + wrapper_api = AzureOpenAiWrapper(api_settings) |
| 201 | + wrapper_entra = AzureOpenAiWrapper(entra_settings) |
| 202 | + |
| 203 | + # Access models to trigger different authentication paths |
| 204 | + try: |
| 205 | + _ = wrapper_api.chat_model |
| 206 | + _ = wrapper_entra.chat_model |
| 207 | + except Exception: |
| 208 | + pass # Expected due to dummy/mock credentials |
| 209 | + |
| 210 | + # Should see both authentication methods being used |
| 211 | + assert "Using API key for authentication" in caplog.text |
| 212 | + assert "Initializing Microsoft Entra ID authentication" in caplog.text |
0 commit comments