Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 3433ee1

Browse files
committed
feat: adds tests and updates run-test.sh
1 parent 3e139f2 commit 3433ee1

File tree

5 files changed

+29
-19
lines changed

5 files changed

+29
-19
lines changed

literalai/api/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import time
32
import os
43
from threading import Lock
54
import uuid
@@ -147,20 +146,21 @@ class SharedPromptCache:
147146
"""
148147
Thread-safe singleton cache for storing prompts with memory leak prevention.
149148
Only one instance will exist regardless of how many times it's instantiated.
150-
Implements LRU eviction policy when cache reaches maximum size.
151149
"""
152150
_instance = None
153151
_lock = Lock()
152+
_shared_cache: dict[str, Prompt]
153+
_name_index: dict[str, str]
154+
_name_version_index: dict[tuple[str, int], str]
154155

155156
def __new__(cls, max_size: int = 1000):
156157
with cls._lock:
157158
if cls._instance is None:
158159
cls._instance = super().__new__(cls)
159160

160-
cls._instance._max_size = max_size
161-
cls._instance._prompts: dict[str, Prompt] = {}
162-
cls._instance._name_index: dict[str, str] = {}
163-
cls._instance._name_version_index: dict[tuple[str, int], str] = {}
161+
cls._instance._shared_cache = {}
162+
cls._instance._name_index = {}
163+
cls._instance._name_version_index = {}
164164
return cls._instance
165165

166166
def get(
@@ -171,7 +171,6 @@ def get(
171171
) -> Optional[Prompt]:
172172
"""
173173
Retrieves a prompt using the most specific criteria provided.
174-
Updates access time for LRU tracking.
175174
Lookup priority: id, name-version, name
176175
"""
177176
if id and not isinstance(id, str):
@@ -196,7 +195,7 @@ def get(
196195

197196
def put(self, prompt: Prompt):
198197
"""
199-
Stores a prompt in the cache, managing size limits with LRU eviction.
198+
Stores a prompt in the cache.
200199
"""
201200
if not isinstance(prompt, Prompt):
202201
raise TypeError("Expected a Prompt object")

literalai/api/prompt_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
if TYPE_CHECKING:
77
from literalai.api import LiteralAPI
8+
from literalai.api import SharedPromptCache
89

910
from literalai.api import gql
1011

run-test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v
1+
LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v tests/e2e/ tests/unit/

tests/e2e/test_e2e.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,20 @@ async def test_prompt(self, async_client: AsyncLiteralClient):
662662

663663
assert messages[0]["content"] == expected
664664

665+
@pytest.mark.timeout(5)
666+
async def test_prompt_cache(self, async_client: AsyncLiteralClient):
667+
prompt = await async_client.api.get_prompt(name="Default", version=0)
668+
assert prompt is not None
669+
670+
original_key = async_client.api.api_key
671+
async_client.api.api_key = "invalid-api-key"
672+
673+
cached_prompt = await async_client.api.get_prompt(name="Default", version=0)
674+
assert cached_prompt is not None
675+
assert cached_prompt.id == prompt.id
676+
677+
async_client.api.api_key = original_key
678+
665679
@pytest.mark.timeout(5)
666680
async def test_prompt_ab_testing(self, client: LiteralClient):
667681
prompt_name = "Python SDK E2E Tests"

tests/unit/test_cache.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
import random
55

66
from literalai.prompt_engineering.prompt import Prompt
7-
from literalai.api import SharedPromptCache
7+
from literalai.api import SharedPromptCache, LiteralAPI, GenerationType
88

99
def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt:
1010
return Prompt(
11-
api=None,
11+
api=LiteralAPI(),
1212
id=id,
1313
name=name,
1414
version=version,
1515
created_at="",
1616
updated_at="",
17-
type="chat",
17+
type="chat", # type: ignore
1818
url="",
1919
version_desc=None,
2020
template_messages=[],
@@ -34,7 +34,7 @@ def test_singleton_instance():
3434
def test_get_empty_cache():
3535
"""Test getting from empty cache returns None"""
3636
cache = SharedPromptCache()
37-
cache.clear() # Ensure clean state
37+
cache.clear()
3838

3939
assert cache._prompts == {}
4040
assert cache._name_index == {}
@@ -90,12 +90,10 @@ def test_multiple_versions():
9090
cache.put(prompt1)
9191
cache.put(prompt2)
9292

93-
# Get specific versions
9493
assert cache.get(name="test", version=1) is prompt1
9594
assert cache.get(name="test", version=2) is prompt2
9695

97-
# Get by name should return latest version
98-
assert cache.get(name="test") is prompt2 # Returns the last indexed version
96+
assert cache.get(name="test") is prompt2
9997

10098
def test_clear_cache():
10199
"""Test clearing the cache"""
@@ -114,7 +112,7 @@ def test_update_existing_prompt():
114112
cache.clear()
115113

116114
prompt1 = default_prompt()
117-
prompt2 = default_prompt(id="1", version=2) # Same ID, different version
115+
prompt2 = default_prompt(id="1", version=2)
118116

119117
cache.put(prompt1)
120118
cache.put(prompt2)
@@ -134,10 +132,8 @@ def test_lookup_priority():
134132
cache.put(prompt1)
135133
cache.put(prompt2)
136134

137-
# ID should take precedence
138135
assert cache.get(id="1", name="test", version=2) is prompt1
139136

140-
# Name-version should take precedence over name
141137
assert cache.get(name="test", version=2) is prompt2
142138

143139
def test_thread_safety():

0 commit comments

Comments
 (0)