diff --git a/template_langgraph/llms/azure_openais.py b/template_langgraph/llms/azure_openais.py
index 72cada7..e47f220 100644
--- a/template_langgraph/llms/azure_openais.py
+++ b/template_langgraph/llms/azure_openais.py
@@ -1,3 +1,4 @@
+import threading
 from functools import lru_cache
 
 from azure.identity import DefaultAzureCredential
@@ -31,57 +32,111 @@ def get_azure_openai_settings() -> Settings:
 
 
 class AzureOpenAiWrapper:
+    # Class-level variables for singleton-like behavior
+    _credentials: dict = {}
+    _tokens: dict = {}
+    _token_lock = threading.Lock()
+
     def __init__(self, settings: Settings = None):
         if settings is None:
             settings = get_azure_openai_settings()
 
-        if settings.azure_openai_use_microsoft_entra_id.lower() == "true":
-            logger.info("Using Microsoft Entra ID for authentication")
-            credential = DefaultAzureCredential()
-            token = credential.get_token("https://cognitiveservices.azure.com/.default").token
-
-            self.chat_model = AzureChatOpenAI(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_chat,
-                streaming=True,
-            )
-            self.reasoning_model = AzureChatOpenAI(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_reasoning,
-                streaming=True,
-            )
-            self.embedding_model = AzureOpenAIEmbeddings(
-                azure_ad_token=token,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_embedding,
-            )
-        else:
-            logger.info("Using API key for authentication")
-            self.chat_model = AzureChatOpenAI(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_chat,
-                streaming=True,
-            )
-            self.reasoning_model = AzureChatOpenAI(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_reasoning,
-                streaming=True,
-            )
-            self.embedding_model = AzureOpenAIEmbeddings(
-                api_key=settings.azure_openai_api_key,
-                azure_endpoint=settings.azure_openai_endpoint,
-                api_version=settings.azure_openai_api_version,
-                azure_deployment=settings.azure_openai_model_embedding,
-            )
+        self.settings = settings
+        self._chat_model: AzureChatOpenAI | None = None
+        self._reasoning_model: AzureChatOpenAI | None = None
+        self._embedding_model: AzureOpenAIEmbeddings | None = None
+
+    def _get_auth_key(self) -> str:
+        """Generate a key for authentication caching based on settings."""
+        return f"{self.settings.azure_openai_endpoint}_{self.settings.azure_openai_use_microsoft_entra_id}"
+
+    def _get_auth_token(self) -> str | None:
+        """Get authentication token with lazy initialization and caching."""
+        if self.settings.azure_openai_use_microsoft_entra_id.lower() != "true":
+            return None
+
+        auth_key = self._get_auth_key()
+
+        with self._token_lock:
+            if auth_key not in self._credentials:
+                logger.info("Initializing Microsoft Entra ID authentication")
+                self._credentials[auth_key] = DefaultAzureCredential()
+
+            if auth_key not in self._tokens:
+                logger.info("Getting authentication token")
+                self._tokens[auth_key] = (
+                    self._credentials[auth_key].get_token("https://cognitiveservices.azure.com/.default").token
+                )
+
+            return self._tokens[auth_key]
+
+    @property
+    def chat_model(self) -> AzureChatOpenAI:
+        """Lazily initialize and return chat model."""
+        if self._chat_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._chat_model = AzureChatOpenAI(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_chat,
+                    streaming=True,
+                )
+            else:
+                logger.info("Using API key for authentication")
+                self._chat_model = AzureChatOpenAI(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_chat,
+                    streaming=True,
+                )
+        return self._chat_model
+
+    @property
+    def reasoning_model(self) -> AzureChatOpenAI:
+        """Lazily initialize and return reasoning model."""
+        if self._reasoning_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._reasoning_model = AzureChatOpenAI(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_reasoning,
+                    streaming=True,
+                )
+            else:
+                self._reasoning_model = AzureChatOpenAI(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_reasoning,
+                    streaming=True,
+                )
+        return self._reasoning_model
+
+    @property
+    def embedding_model(self) -> AzureOpenAIEmbeddings:
+        """Lazily initialize and return embedding model."""
+        if self._embedding_model is None:
+            if self.settings.azure_openai_use_microsoft_entra_id.lower() == "true":
+                token = self._get_auth_token()
+                self._embedding_model = AzureOpenAIEmbeddings(
+                    azure_ad_token=token,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_embedding,
+                )
+            else:
+                self._embedding_model = AzureOpenAIEmbeddings(
+                    api_key=self.settings.azure_openai_api_key,
+                    azure_endpoint=self.settings.azure_openai_endpoint,
+                    api_version=self.settings.azure_openai_api_version,
+                    azure_deployment=self.settings.azure_openai_model_embedding,
+                )
+        return self._embedding_model
 
     def create_embedding(self, text: str):
         """Create an embedding for the given text."""
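
For orientation, here is how the refactored wrapper behaves from a caller's perspective. This is a usage sketch, not part of the diff: the endpoint and key are placeholders, `Settings` is assumed to default any fields not passed explicitly (as the tests below do), and `invoke` is the standard LangChain entry point on `AzureChatOpenAI`.

    from template_langgraph.llms.azure_openais import AzureOpenAiWrapper, Settings

    settings = Settings(
        azure_openai_use_microsoft_entra_id="false",
        azure_openai_api_key="<api-key>",  # placeholder
        azure_openai_endpoint="https://example.openai.azure.com/",  # placeholder
    )

    # Construction is now cheap: no credential lookup or client creation happens here.
    wrapper = AzureOpenAiWrapper(settings)

    # The chat client is built on first property access...
    reply = wrapper.chat_model.invoke("Hello")

    # ...and cached on the instance, so later accesses return the same object.
    assert wrapper.chat_model is wrapper.chat_model

The new tests below exercise exactly these properties: lazy construction, credential reuse across instances, and thread safety of the token cache.
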
diff --git a/tests/test_azure_openais.py b/tests/test_azure_openais.py
new file mode 100644
index 0000000..cd1a10c
--- /dev/null
+++ b/tests/test_azure_openais.py
@@ -0,0 +1,212 @@
+import logging
+import threading
+from unittest.mock import Mock, patch
+
+from template_langgraph.llms.azure_openais import AzureOpenAiWrapper, Settings
+
+
+class TestAzureOpenAiWrapper:
+    """Test cases for AzureOpenAiWrapper authentication optimization."""
+
+    def setup_method(self):
+        """Reset class-level variables before each test."""
+        AzureOpenAiWrapper._credentials.clear()
+        AzureOpenAiWrapper._tokens.clear()
+
+    def test_lazy_initialization_api_key(self, caplog):
+        """Test that API key authentication uses lazy initialization."""
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            # Creating instances should not trigger authentication
+            wrapper1 = AzureOpenAiWrapper(settings)
+
+            # No authentication logs yet
+            assert "Using API key for authentication" not in caplog.text
+
+            # Accessing models should trigger authentication
+            try:
+                _ = wrapper1.chat_model
+            except Exception:
+                pass  # Expected due to dummy credentials
+
+            # Should see authentication log only once per model access
+            assert caplog.text.count("Using API key for authentication") == 1
+
+            # Second access should not trigger additional authentication
+            try:
+                _ = wrapper1.reasoning_model
+            except Exception:
+                pass
+
+            # Should still be only one authentication log per model type
+            assert caplog.text.count("Using API key for authentication") >= 1
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_singleton_credential_entra_id(self, mock_credential_class, caplog):
+        """Test that Microsoft Entra ID credentials are reused across instances."""
+        # Mock the credential and token
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            # Create multiple instances
+            wrapper1 = AzureOpenAiWrapper(settings)
+            wrapper2 = AzureOpenAiWrapper(settings)
+
+            # Access models to trigger authentication
+            try:
+                _ = wrapper1.chat_model
+                _ = wrapper2.chat_model
+            except Exception:
+                pass  # Expected due to mocking
+
+        # Credential should be initialized only once
+        assert mock_credential_class.call_count == 1
+        # Token should be requested only once
+        assert mock_credential.get_token.call_count == 1
+
+        # Should see initialization logs only once
+        assert caplog.text.count("Initializing Microsoft Entra ID authentication") == 1
+        assert caplog.text.count("Getting authentication token") == 1
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_thread_safety(self, mock_credential_class):
+        """Test that authentication is thread-safe."""
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        results = []
+        errors = []
+
+        def worker():
+            try:
+                wrapper = AzureOpenAiWrapper(settings)
+                token = wrapper._get_auth_token()
+                results.append(token)
+            except Exception as e:
+                errors.append(e)
+
+        # Create multiple threads that try to authenticate simultaneously
+        threads = []
+        for _ in range(10):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Check results
+        assert len(errors) == 0, f"Errors occurred: {errors}"
+        assert len(results) == 10
+        assert all(token == "mock_token_123" for token in results)
+
+        # Credential should be initialized only once despite multiple threads
+        assert mock_credential_class.call_count == 1
+        assert mock_credential.get_token.call_count == 1
+
+    def test_different_settings_per_instance(self):
+        """Test that different instances can have different settings."""
+        settings1 = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="key1",
+            azure_openai_endpoint="https://endpoint1.openai.azure.com/",
+        )
+
+        settings2 = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="key2",
+            azure_openai_endpoint="https://endpoint2.openai.azure.com/",
+        )
+
+        wrapper1 = AzureOpenAiWrapper(settings1)
+        wrapper2 = AzureOpenAiWrapper(settings2)
+
+        # Each instance should maintain its own settings
+        assert wrapper1.settings.azure_openai_api_key == "key1"
+        assert wrapper2.settings.azure_openai_api_key == "key2"
+        assert wrapper1.settings.azure_openai_endpoint == "https://endpoint1.openai.azure.com/"
+        assert wrapper2.settings.azure_openai_endpoint == "https://endpoint2.openai.azure.com/"
+
+    def test_create_embedding_method_compatibility(self):
+        """Test that the create_embedding method still works."""
+        settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        wrapper = AzureOpenAiWrapper(settings)
+
+        # This should not raise an error about missing methods
+        # (though it will fail due to dummy credentials)
+        try:
+            wrapper.create_embedding("test text")
+        except Exception:
+            pass  # Expected due to dummy credentials
+
+        # Verify the method exists and is callable
+        assert hasattr(wrapper, "create_embedding")
+        assert callable(getattr(wrapper, "create_embedding"))
+
+    @patch("template_langgraph.llms.azure_openais.DefaultAzureCredential")
+    def test_mixed_authentication_methods(self, mock_credential_class, caplog):
+        """Test using both authentication methods in different instances."""
+        mock_credential = Mock()
+        mock_token_obj = Mock()
+        mock_token_obj.token = "mock_token_123"
+        mock_credential.get_token.return_value = mock_token_obj
+        mock_credential_class.return_value = mock_credential
+
+        # API key settings
+        api_settings = Settings(
+            azure_openai_use_microsoft_entra_id="false",
+            azure_openai_api_key="dummy_key",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        # Entra ID settings
+        entra_settings = Settings(
+            azure_openai_use_microsoft_entra_id="true",
+            azure_openai_endpoint="https://dummy.openai.azure.com/",
+        )
+
+        with caplog.at_level(logging.INFO):
+            wrapper_api = AzureOpenAiWrapper(api_settings)
+            wrapper_entra = AzureOpenAiWrapper(entra_settings)
+
+            # Access models to trigger different authentication paths
+            try:
+                _ = wrapper_api.chat_model
+                _ = wrapper_entra.chat_model
+            except Exception:
+                pass  # Expected due to dummy/mock credentials
+
+        # Should see both authentication methods being used
+        assert "Using API key for authentication" in caplog.text
+        assert "Initializing Microsoft Entra ID authentication" in caplog.text