214 changes: 191 additions & 23 deletions langfuse/_client/client.py
@@ -3199,29 +3199,15 @@ def get_prompt(
                 langfuse_logger.warning(
                     f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
                 )
-
-                fallback_client_args: Dict[str, Any] = {
-                    "name": name,
-                    "prompt": fallback,
-                    "type": type,
-                    "version": version or 0,
-                    "config": {},
-                    "labels": [label] if label else [],
-                    "tags": [],
-                }
-
-                if type == "text":
-                    return TextPromptClient(
-                        prompt=Prompt_Text(**fallback_client_args),
-                        is_fallback=True,
-                    )
-
-                if type == "chat":
-                    return ChatPromptClient(
-                        prompt=Prompt_Chat(**fallback_client_args),
-                        is_fallback=True,
-                    )
-
+                fallback_prompt = self._create_fallback_prompt_client(
+                    name=name,
+                    prompt=fallback,
+                    type=type,
+                    version=version,
+                    label=label,
+                )
+                if fallback_prompt is not None:
+                    return fallback_prompt
                 raise e
 
         if cached_prompt.is_expired():
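
The fallback path above now delegates to the shared _create_fallback_prompt_client helper (added further down), so text and chat fallbacks are built in one place. A minimal usage sketch of the caller-facing behavior this preserves (the prompt name and template are illustrative, not from this PR):

from langfuse import Langfuse

langfuse = Langfuse()

# If the fetch fails and a fallback was supplied, get_prompt returns a
# client built from the inline fallback instead of raising; such clients
# report is_fallback=True and version 0 (see _create_fallback_prompt_client).
prompt = langfuse.get_prompt(
    "movie-critic",
    type="text",
    fallback="Tell me about {{movie}}.",
)
if prompt.is_fallback:
    print("Serving fallback prompt")
print(prompt.compile(movie="Dune"))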
@@ -3259,6 +3245,101 @@ def refresh_task() -> None:

         return cached_prompt.value
 
+    async def aget_prompt(
+        self,
+        name: str,
+        *,
+        version: Optional[int] = None,
+        label: Optional[str] = None,
+        type: Literal["chat", "text"] = "text",
+        cache_ttl_seconds: Optional[int] = None,
+        fallback: Union[Optional[List[ChatMessageDict]], Optional[str]] = None,
+        max_retries: Optional[int] = None,
+        fetch_timeout_seconds: Optional[int] = None,
+    ) -> PromptClient:
+        """Async variant of get_prompt with identical semantics."""
+        if self._resources is None:
+            raise Error(
+                "SDK is not correctly initialized. Check the init logs for more details."
+            )
+        if version is not None and label is not None:
+            raise ValueError("Cannot specify both version and label at the same time.")
+        if not name:
+            raise ValueError("Prompt name cannot be empty.")
+
+        cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
+        bounded_max_retries = self._get_bounded_max_retries(
+            max_retries, default_max_retries=2, max_retries_upper_bound=4
+        )
+
+        langfuse_logger.debug(f"Getting prompt '{cache_key}' (async)")
+        cached_prompt = self._resources.prompt_cache.get(cache_key)
+
+        if cached_prompt is None or cache_ttl_seconds == 0:
+            langfuse_logger.debug(
+                f"Prompt '{cache_key}' not found in cache or caching disabled (async)."
+            )
+            try:
+                return await self._fetch_prompt_and_update_cache_async(
+                    name,
+                    version=version,
+                    label=label,
+                    ttl_seconds=cache_ttl_seconds,
+                    max_retries=bounded_max_retries,
+                    fetch_timeout_seconds=fetch_timeout_seconds,
+                )
+            except Exception as e:
+                if fallback:
+                    langfuse_logger.warning(
+                        f"Returning fallback prompt for '{cache_key}' due to async fetch error: {e}"
+                    )
+                    fallback_prompt = self._create_fallback_prompt_client(
+                        name=name,
+                        prompt=fallback,
+                        type=type,
+                        version=version,
+                        label=label,
+                    )
+                    if fallback_prompt is not None:
+                        return fallback_prompt
+                raise e
+
+        if cached_prompt.is_expired():
+            langfuse_logger.debug(
+                f"Stale prompt '{cache_key}' found in cache (async). Refresh scheduled."
+            )
+            try:
+
+                async def refresh_coroutine() -> None:
+                    await self._fetch_prompt_and_update_cache_async(
+                        name,
+                        version=version,
+                        label=label,
+                        ttl_seconds=cache_ttl_seconds,
+                        max_retries=bounded_max_retries,
+                        fetch_timeout_seconds=fetch_timeout_seconds,
+                    )
+
+                def refresh_task() -> None:
+                    run_async_safely(refresh_coroutine())
+
+                self._resources.prompt_cache.add_refresh_prompt_task(
+                    cache_key,
+                    refresh_task,
+                )
+                langfuse_logger.debug(
+                    f"Returning stale prompt '{cache_key}' from cache (async)."
+                )
+                return cached_prompt.value
+
+            except Exception as e:
+                langfuse_logger.warning(
+                    f"Error when scheduling async refresh for cached prompt '{cache_key}', returning cached version. Error: {e}"
+                )
+                return cached_prompt.value
+
+        return cached_prompt.value
+
     def _fetch_prompt_and_update_cache(
         self,
         name: str,
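
The new aget_prompt above mirrors the synchronous flow: a cache hit returns immediately, a miss awaits _fetch_prompt_and_update_cache_async, and a stale hit returns the cached value while scheduling a background refresh through run_async_safely. A short sketch of how it might be called (handler name and prompt values are illustrative):

import asyncio

from langfuse import Langfuse

langfuse = Langfuse()


async def handle_request() -> str:
    # Same semantics as get_prompt, but awaitable, so the event loop is
    # not blocked while the prompt is fetched from the API.
    prompt = await langfuse.aget_prompt(
        "movie-critic",
        label="production",
        fallback="Tell me about {{movie}}.",
    )
    return prompt.compile(movie="Dune")


print(asyncio.run(handle_request()))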
@@ -3308,6 +3389,93 @@ def fetch_prompts() -> Any:
             )
             raise e
 
+    async def _fetch_prompt_and_update_cache_async(
+        self,
+        name: str,
+        *,
+        version: Optional[int] = None,
+        label: Optional[str] = None,
+        ttl_seconds: Optional[int] = None,
+        max_retries: int,
+        fetch_timeout_seconds: Optional[int],
+    ) -> PromptClient:
+        cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
+        langfuse_logger.debug(
+            f"Fetching prompt '{cache_key}' from server asynchronously..."
+        )
+
+        try:
+
+            @backoff.on_exception(
+                backoff.constant, Exception, max_tries=max_retries + 1, logger=None
+            )
+            async def fetch_prompts() -> Any:
+                return await self.async_api.prompts.get(
+                    self._url_encode(name),
+                    version=version,
+                    label=label,
+                    request_options={
+                        "timeout_in_seconds": fetch_timeout_seconds,
+                    }
+                    if fetch_timeout_seconds is not None
+                    else None,
+                )
+
+            prompt_response = await fetch_prompts()
+
+            prompt: PromptClient
+            if prompt_response.type == "chat":
+                prompt = ChatPromptClient(prompt_response)
+            else:
+                prompt = TextPromptClient(prompt_response)
+
+            if self._resources is not None:
+                self._resources.prompt_cache.set(cache_key, prompt, ttl_seconds)
+
+            return prompt
+
+        except Exception as e:
+            langfuse_logger.error(
+                f"Error while asynchronously fetching prompt '{cache_key}': {str(e)}"
+            )
+            raise e
+
+    def _create_fallback_prompt_client(
+        self,
+        *,
+        name: str,
+        prompt: Union[Optional[List[ChatMessageDict]], Optional[str]],
+        type: Literal["chat", "text"],
+        version: Optional[int],
+        label: Optional[str],
+    ) -> Optional[PromptClient]:
+        if prompt is None:
+            return None
+
+        fallback_client_args: Dict[str, Any] = {
+            "name": name,
+            "prompt": prompt,
+            "type": type,
+            "version": version or 0,
+            "config": {},
+            "labels": [label] if label else [],
+            "tags": [],
+        }
+
+        if type == "text":
+            return TextPromptClient(
+                prompt=Prompt_Text(**fallback_client_args),
+                is_fallback=True,
+            )
+
+        if type == "chat":
+            return ChatPromptClient(
+                prompt=Prompt_Chat(**fallback_client_args),
+                is_fallback=True,
+            )
+
+        return None
+
     def _get_bounded_max_retries(
         self,
         max_retries: Optional[int],
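
The retry wrapper in _fetch_prompt_and_update_cache_async relies on backoff.on_exception, which also supports coroutine functions. A self-contained sketch of the semantics; note that max_tries counts total attempts, which is why the code above passes max_retries + 1:

import asyncio

import backoff

attempts = 0


@backoff.on_exception(backoff.constant, Exception, max_tries=3, logger=None)
async def flaky_fetch() -> str:
    # Fails twice, then succeeds: three total attempts fit within max_tries=3.
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient error")
    return "ok"


print(asyncio.run(flaky_fetch()), f"(after {attempts} attempts)")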
93 changes: 92 additions & 1 deletion tests/test_prompt.py
@@ -1,12 +1,14 @@
+import asyncio
 from time import sleep
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import openai
 import pytest
 
 from langfuse._client.client import Langfuse
 from langfuse._utils.prompt_cache import (
     DEFAULT_PROMPT_CACHE_TTL_SECONDS,
+    PromptCache,
     PromptCacheItem,
 )
 from langfuse.api.resources.prompts import Prompt_Chat, Prompt_Text
@@ -681,6 +683,9 @@ def test_prompt_end_to_end():
 def langfuse():
     langfuse_instance = Langfuse()
     langfuse_instance.api = Mock()
+    langfuse_instance.async_api = Mock()
+    langfuse_instance._resources = Mock()
+    langfuse_instance._resources.prompt_cache = PromptCache()
 
     return langfuse_instance
 
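The tests below exercise the async path with unittest.mock.AsyncMock, whose calls return awaitables and support await-specific assertions. A standalone sketch of the pattern (names are illustrative):

import asyncio
from unittest.mock import AsyncMock

# Calling an AsyncMock yields a coroutine resolving to return_value.
fake_get = AsyncMock(return_value="prompt payload")

result = asyncio.run(fake_get("my-prompt", version=None))

assert result == "prompt payload"
fake_get.assert_awaited_once_with("my-prompt", version=None)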
@@ -712,6 +717,92 @@ def test_get_fresh_prompt(langfuse):
     assert result == TextPromptClient(prompt)
 
 
+def test_async_get_fresh_prompt(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_fresh_prompt"
+    prompt = Prompt_Text(
+        name=prompt_name,
+        version=1,
+        prompt="Make me laugh",
+        type="text",
+        labels=[],
+        config={},
+        tags=[],
+    )
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock(return_value=prompt)
+
+    result = asyncio.run(langfuse.aget_prompt(prompt_name, fallback="fallback"))
+
+    langfuse.async_api.prompts.get.assert_awaited_once_with(
+        prompt_name,
+        version=None,
+        label=None,
+        request_options=None,
+    )
+    assert result == TextPromptClient(prompt)
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    cached_item = langfuse._resources.prompt_cache.get(cache_key)
+    assert cached_item is not None
+    assert cached_item.value == result
+
+
+def test_async_get_prompt_uses_cache_without_fetch(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_prompt_uses_cache_without_fetch"
+    prompt = Prompt_Text(
+        name=prompt_name,
+        version=1,
+        prompt="Cached prompt",
+        type="text",
+        labels=[],
+        config={},
+        tags=[],
+    )
+    prompt_client = TextPromptClient(prompt)
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    langfuse._resources.prompt_cache.set(cache_key, prompt_client, ttl_seconds=60)
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock()
+
+    result = asyncio.run(langfuse.aget_prompt(prompt_name))
+
+    assert result == prompt_client
+    langfuse.async_api.prompts.get.assert_not_called()
+
+
+def test_async_get_prompt_returns_fallback_on_failure(langfuse):
+    langfuse._resources.prompt_cache.clear()
+    prompt_name = "test_async_get_prompt_returns_fallback_on_failure"
+
+    langfuse.async_api.prompts = Mock()
+    langfuse.async_api.prompts.get = AsyncMock(side_effect=Exception("boom"))
+
+    result = asyncio.run(
+        langfuse.aget_prompt(
+            prompt_name, fallback="fallback text", max_retries=0
+        )
+    )
+
+    assert isinstance(result, TextPromptClient)
+    assert result.is_fallback is True
+    assert result.prompt == "fallback text"
+
+    cache_key = langfuse._resources.prompt_cache.generate_cache_key(
+        prompt_name, version=None, label=None
+    )
+    cached_item = langfuse._resources.prompt_cache.get(cache_key)
+    assert cached_item is None
+
+
 # Should throw an error if prompt name is unspecified
 def test_throw_if_name_unspecified(langfuse):
     prompt_name = ""
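
One piece of the refresh path in aget_prompt worth spelling out: refresh_task invokes run_async_safely(refresh_coroutine()) from a synchronous cache callback. A generic sketch of that pattern follows; the helper body here is an assumption for illustration, not the SDK's exact implementation:

import asyncio
import threading


def run_async_safely(coro) -> None:
    # Assumed strategy: run a coroutine from synchronous code, handling
    # the case where an event loop is already running in this thread.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to run directly.
        asyncio.run(coro)
        return
    # A loop is already running: execute the coroutine on a fresh loop
    # in a separate thread to avoid "event loop is running" errors.
    thread = threading.Thread(target=asyncio.run, args=(coro,))
    thread.start()
    thread.join()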