diff --git a/langfuse/_client/client.py b/langfuse/_client/client.py
index 55a9667e3..869e6264d 100644
--- a/langfuse/_client/client.py
+++ b/langfuse/_client/client.py
@@ -3199,29 +3199,15 @@ def get_prompt(
                     langfuse_logger.warning(
                         f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
                     )
-
-                    fallback_client_args: Dict[str, Any] = {
-                        "name": name,
-                        "prompt": fallback,
-                        "type": type,
-                        "version": version or 0,
-                        "config": {},
-                        "labels": [label] if label else [],
-                        "tags": [],
-                    }
-
-                    if type == "text":
-                        return TextPromptClient(
-                            prompt=Prompt_Text(**fallback_client_args),
-                            is_fallback=True,
-                        )
-
-                    if type == "chat":
-                        return ChatPromptClient(
-                            prompt=Prompt_Chat(**fallback_client_args),
-                            is_fallback=True,
-                        )
-
+                    fallback_prompt = self._create_fallback_prompt_client(
+                        name=name,
+                        prompt=fallback,
+                        type=type,
+                        version=version,
+                        label=label,
+                    )
+                    if fallback_prompt is not None:
+                        return fallback_prompt
                 raise e
 
         if cached_prompt.is_expired():
@@ -3259,6 +3245,101 @@ def refresh_task() -> None:
 
         return cached_prompt.value
 
+    async def aget_prompt(
+        self,
+        name: str,
+        *,
+        version: Optional[int] = None,
+        label: Optional[str] = None,
+        type: Literal["chat", "text"] = "text",
+        cache_ttl_seconds: Optional[int] = None,
+        fallback: Union[Optional[List[ChatMessageDict]], Optional[str]] = None,
+        max_retries: Optional[int] = None,
+        fetch_timeout_seconds: Optional[int] = None,
+    ) -> PromptClient:
+        """Async variant of get_prompt with identical semantics."""
+        if self._resources is None:
+            raise Error(
+                "SDK is not correctly initialized. Check the init logs for more details."
+            )
+        if version is not None and label is not None:
+            raise ValueError("Cannot specify both version and label at the same time.")
+        if not name:
+            raise ValueError("Prompt name cannot be empty.")
+
+        cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
+        bounded_max_retries = self._get_bounded_max_retries(
+            max_retries, default_max_retries=2, max_retries_upper_bound=4
+        )
+
+        langfuse_logger.debug(f"Getting prompt '{cache_key}' (async)")
+        cached_prompt = self._resources.prompt_cache.get(cache_key)
+
+        if cached_prompt is None or cache_ttl_seconds == 0:
+            langfuse_logger.debug(
+                f"Prompt '{cache_key}' not found in cache or caching disabled (async)."
+            )
+            try:
+                return await self._fetch_prompt_and_update_cache_async(
+                    name,
+                    version=version,
+                    label=label,
+                    ttl_seconds=cache_ttl_seconds,
+                    max_retries=bounded_max_retries,
+                    fetch_timeout_seconds=fetch_timeout_seconds,
+                )
+            except Exception as e:
+                if fallback:
+                    langfuse_logger.warning(
+                        f"Returning fallback prompt for '{cache_key}' due to async fetch error: {e}"
+                    )
+                    fallback_prompt = self._create_fallback_prompt_client(
+                        name=name,
+                        prompt=fallback,
+                        type=type,
+                        version=version,
+                        label=label,
+                    )
+                    if fallback_prompt is not None:
+                        return fallback_prompt
+                raise e
+
+        if cached_prompt.is_expired():
+            langfuse_logger.debug(
+                f"Stale prompt '{cache_key}' found in cache (async). Refresh scheduled."
+            )
+            try:
+
+                async def refresh_coroutine() -> None:
+                    await self._fetch_prompt_and_update_cache_async(
+                        name,
+                        version=version,
+                        label=label,
+                        ttl_seconds=cache_ttl_seconds,
+                        max_retries=bounded_max_retries,
+                        fetch_timeout_seconds=fetch_timeout_seconds,
+                    )
+
+                def refresh_task() -> None:
+                    run_async_safely(refresh_coroutine())
+
+                self._resources.prompt_cache.add_refresh_prompt_task(
+                    cache_key,
+                    refresh_task,
+                )
+                langfuse_logger.debug(
+                    f"Returning stale prompt '{cache_key}' from cache (async)."
+                )
+                return cached_prompt.value
+
+            except Exception as e:
+                langfuse_logger.warning(
+                    f"Error when scheduling async refresh for cached prompt '{cache_key}', returning cached version. Error: {e}"
+                )
+                return cached_prompt.value
+
+        return cached_prompt.value
+
     def _fetch_prompt_and_update_cache(
         self,
         name: str,
@@ -3308,6 +3389,93 @@ def fetch_prompts() -> Any:
             )
             raise e
 
+    async def _fetch_prompt_and_update_cache_async(
+        self,
+        name: str,
+        *,
+        version: Optional[int] = None,
+        label: Optional[str] = None,
+        ttl_seconds: Optional[int] = None,
+        max_retries: int,
+        fetch_timeout_seconds: Optional[int],
+    ) -> PromptClient:
+        cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
+        langfuse_logger.debug(
+            f"Fetching prompt '{cache_key}' from server asynchronously..."
+        )
+
+        try:
+
+            @backoff.on_exception(
+                backoff.constant, Exception, max_tries=max_retries + 1, logger=None
+            )
+            async def fetch_prompts() -> Any:
+                return await self.async_api.prompts.get(
+                    self._url_encode(name),
+                    version=version,
+                    label=label,
+                    request_options={
+                        "timeout_in_seconds": fetch_timeout_seconds,
+                    }
+                    if fetch_timeout_seconds is not None
+                    else None,
+                )
+
+            prompt_response = await fetch_prompts()
+
+            prompt: PromptClient
+            if prompt_response.type == "chat":
+                prompt = ChatPromptClient(prompt_response)
+            else:
+                prompt = TextPromptClient(prompt_response)
+
+            if self._resources is not None:
+                self._resources.prompt_cache.set(cache_key, prompt, ttl_seconds)
+
+            return prompt
+
+        except Exception as e:
+            langfuse_logger.error(
+                f"Error while asynchronously fetching prompt '{cache_key}': {str(e)}"
+            )
+            raise e
+
+    def _create_fallback_prompt_client(
+        self,
+        *,
+        name: str,
+        prompt: Union[Optional[List[ChatMessageDict]], Optional[str]],
+        type: Literal["chat", "text"],
+        version: Optional[int],
+        label: Optional[str],
+    ) -> Optional[PromptClient]:
+        if prompt is None:
+            return None
+
+        fallback_client_args: Dict[str, Any] = {
+            "name": name,
+            "prompt": prompt,
+            "type": type,
+            "version": version or 0,
+            "config": {},
+            "labels": [label] if label else [],
+            "tags": [],
+        }
+
+        if type == "text":
+            return TextPromptClient(
+                prompt=Prompt_Text(**fallback_client_args),
+                is_fallback=True,
+            )
+
+        if type == "chat":
+            return ChatPromptClient(
+                prompt=Prompt_Chat(**fallback_client_args),
+                is_fallback=True,
+            )
+
+        return None
+
     def _get_bounded_max_retries(
         self,
         max_retries: Optional[int],
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index e5346debf..db17601d7 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -1,5 +1,6 @@
+import asyncio
 from time import sleep
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import openai
 import pytest
@@ -7,6 +8,7 @@
 from langfuse._client.client import Langfuse
 from langfuse._utils.prompt_cache import (
     DEFAULT_PROMPT_CACHE_TTL_SECONDS,
+    PromptCache,
     PromptCacheItem,
 )
 from langfuse.api.resources.prompts import Prompt_Chat, Prompt_Text
@@ -681,6 +683,9 @@ def test_prompt_end_to_end():
 def langfuse():
     langfuse_instance = Langfuse()
     langfuse_instance.api = Mock()
+    langfuse_instance.async_api = Mock()
+    langfuse_instance._resources = Mock()
+    langfuse_instance._resources.prompt_cache = PromptCache()
 
     return langfuse_instance
 
@@ -712,6 +717,92 @@ def test_get_fresh_prompt(langfuse):
     assert result == TextPromptClient(prompt)
 
 
+def test_async_get_fresh_prompt(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_fresh_prompt"
+    prompt = Prompt_Text(
+        name=prompt_name,
+        version=1,
+        prompt="Make me laugh",
+        type="text",
+        labels=[],
+        config={},
+        tags=[],
+    )
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock(return_value=prompt)
+
+    result = asyncio.run(langfuse.aget_prompt(prompt_name, fallback="fallback"))
+
+    langfuse.async_api.prompts.get.assert_awaited_once_with(
+        prompt_name,
+        version=None,
+        label=None,
+        request_options=None,
+    )
+    assert result == TextPromptClient(prompt)
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    cached_item = langfuse._resources.prompt_cache.get(cache_key)
+    assert cached_item is not None
+    assert cached_item.value == result
+
+
+def test_async_get_prompt_uses_cache_without_fetch(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_prompt_uses_cache_without_fetch"
+    prompt = Prompt_Text(
+        name=prompt_name,
+        version=1,
+        prompt="Cached prompt",
+        type="text",
+        labels=[],
+        config={},
+        tags=[],
+    )
+    prompt_client = TextPromptClient(prompt)
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    langfuse._resources.prompt_cache.set(cache_key, prompt_client, ttl_seconds=60)
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock()
+
+    result = asyncio.run(langfuse.aget_prompt(prompt_name))
+
+    assert result == prompt_client
+    langfuse.async_api.prompts.get.assert_not_called()
+
+
+def test_async_get_prompt_returns_fallback_on_failure(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_prompt_returns_fallback_on_failure"
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock(side_effect=Exception("boom"))
+
+    result = asyncio.run(
+        langfuse.aget_prompt(
+            prompt_name, fallback="fallback text", max_retries=0
+        )
+    )
+
+    assert isinstance(result, TextPromptClient)
+    assert result.is_fallback is True
+    assert result.prompt == "fallback text"
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    cached_item = langfuse._resources.prompt_cache.get(cache_key)
+    assert cached_item is None
+
+
 # Should throw an error if prompt name is unspecified
 def test_throw_if_name_unspecified(langfuse):
     prompt_name = ""
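
For reference, a minimal usage sketch of the new `aget_prompt` API introduced by this patch, assuming a configured client. The prompt name and fallback string are hypothetical, and the import path mirrors the one used in the tests above; application code may prefer the package's public export.

```python
import asyncio

from langfuse._client.client import Langfuse


async def main() -> None:
    langfuse = Langfuse()

    # Resolves from the shared prompt cache when fresh, otherwise awaits a
    # server fetch without blocking the event loop. If the fetch fails and a
    # fallback is given, a client with is_fallback=True is returned instead
    # of raising.
    prompt = await langfuse.aget_prompt(
        "movie-critic",  # hypothetical prompt name, for illustration only
        type="text",
        fallback="You are a movie critic.",  # hypothetical fallback text
    )

    if prompt.is_fallback:
        print("Prompt fetch failed; serving fallback.")
    print(prompt.prompt)


asyncio.run(main())
```

On a stale cache hit, `aget_prompt` returns the cached value immediately and schedules the refetch (wrapped with `run_async_safely`) via the prompt cache's refresh-task mechanism, so callers never wait on revalidation.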