Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 0bd9044

Browse files
committed
fix: test and implementation
1 parent 5318751 commit 0bd9044

File tree

3 files changed

+48
-132
lines changed

3 files changed

+48
-132
lines changed

literalai/api/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,15 @@ def handle_bytes(item):
143143

144144
class SharedCache:
145145
"""
146-
Thread-safe singleton cache for storing prompts with memory leak prevention.
146+
Singleton cache for storing data.
147147
Only one instance will exist regardless of how many times it's instantiated.
148148
"""
149149
_instance = None
150-
_cache: dict[str, Any]
151150

152151
def __new__(cls):
153152
if cls._instance is None:
154153
cls._instance = super().__new__(cls)
155-
cls._instance.cache = {}
154+
cls._instance._cache = {}
156155
return cls._instance
157156

158157
def get_cache(self) -> dict[str, Any]:
@@ -162,14 +161,23 @@ def get(self, key: str) -> Optional[Any]:
162161
"""
163162
Retrieves a value from the cache using the provided key.
164163
"""
164+
if not isinstance(key, str):
165+
raise TypeError("Key must be a string")
165166
return self._cache.get(key)
166167

167168
def put(self, key: str, value: Any):
168169
"""
169170
Stores a value in the cache.
170171
"""
172+
if not isinstance(key, str):
173+
raise TypeError("Key must be a string")
171174
self._cache[key] = value
172175

176+
def put_prompt(self, prompt: Prompt):
177+
self.put(prompt.id, prompt)
178+
self.put(prompt.name, prompt)
179+
self.put(f"{prompt.name}-{prompt.version}", prompt)
180+
173181
def clear(self) -> None:
174182
"""
175183
Clears all cached values.
@@ -1398,7 +1406,7 @@ def get_prompt(
13981406
raise ValueError("Either the `id` or the `name` must be provided.")
13991407

14001408
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
1401-
api=self,id=id, name=name, version=version, prompt_cache=self.prompt_cache
1409+
api=self,id=id, name=name, version=version, cache=self.cache
14021410
)
14031411

14041412
try:
@@ -1407,9 +1415,8 @@ def get_prompt(
14071415
elif name:
14081416
prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
14091417

1410-
self.cache.put(prompt.id, prompt)
1411-
self.cache.put(prompt.name, prompt)
1412-
self.cache.put(f"{prompt.name}-{prompt.version}", prompt)
1418+
self.cache.put_prompt(prompt)
1419+
14131420
return prompt
14141421

14151422
except Exception as e:
@@ -2645,7 +2652,7 @@ async def get_prompt(
26452652

26462653
sync_api = LiteralAPI(self.api_key, self.url)
26472654
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
2648-
api=sync_api, id=id, name=name, version=version, prompt_cache=self.prompt_cache
2655+
api=sync_api, id=id, name=name, version=version, cache=self.cache
26492656
)
26502657

26512658
try:
@@ -2658,7 +2665,6 @@ async def get_prompt(
26582665
get_prompt_query, description, variables, process_response, timeout
26592666
)
26602667

2661-
self.prompt_cache.put(prompt)
26622668
return prompt
26632669

26642670
except Exception as e:

literalai/api/prompt_helpers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
if TYPE_CHECKING:
77
from literalai.api import LiteralAPI
8-
from literalai.api import SharedPromptCache
8+
from literalai.api import SharedCache
99

1010
from literalai.api import gql
1111

@@ -73,24 +73,24 @@ def get_prompt_helper(
7373
id: Optional[str] = None,
7474
name: Optional[str] = None,
7575
version: Optional[int] = 0,
76-
prompt_cache: Optional["SharedPromptCache"] = None,
76+
cache: Optional["SharedCache"] = None,
7777
) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
7878
"""Helper function for getting prompts with caching logic"""
7979

8080
cached_prompt = None
8181
timeout = 10
8282

83-
if prompt_cache:
84-
cached_prompt = prompt_cache.get(get_prompt_cache_key(id, name, version))
83+
if cache:
84+
cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
8585
timeout = 1 if cached_prompt else timeout
8686

8787
variables = {"id": id, "name": name, "version": version}
8888

8989
def process_response(response):
9090
prompt_version = response["data"]["promptVersion"]
9191
prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
92-
if prompt_cache:
93-
prompt_cache.put(prompt)
92+
if cache and prompt:
93+
cache.put_prompt(prompt)
9494
return prompt
9595

9696
description = "get prompt"

tests/unit/test_cache.py

Lines changed: 27 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import pytest
2-
from threading import Thread
3-
import time
4-
import random
52

63
from literalai.prompt_engineering.prompt import Prompt
7-
from literalai.api import SharedPromptCache, LiteralAPI
4+
from literalai.api import SharedCache, LiteralAPI
85

96
def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt:
107
return Prompt(
@@ -26,156 +23,69 @@ def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Promp
2623
)
2724

2825
def test_singleton_instance():
29-
"""Test that SharedPromptCache maintains singleton pattern"""
30-
cache1 = SharedPromptCache()
31-
cache2 = SharedPromptCache()
26+
"""Test that SharedCache maintains singleton pattern"""
27+
cache1 = SharedCache()
28+
cache2 = SharedCache()
3229
assert cache1 is cache2
3330

3431
def test_get_empty_cache():
3532
"""Test getting from empty cache returns None"""
36-
cache = SharedPromptCache()
33+
cache = SharedCache()
3734
cache.clear()
3835

39-
assert cache._prompts == {}
40-
assert cache._name_index == {}
41-
assert cache._name_version_index == {}
36+
assert cache.get_cache() == {}
4237

43-
def test_put_and_get_by_id():
44-
"""Test storing and retrieving prompt by ID"""
45-
cache = SharedPromptCache()
38+
def test_put_and_get_prompt_by_id_by_name_version_by_name():
39+
"""Test storing and retrieving prompt by ID by name-version by name"""
40+
cache = SharedCache()
4641
cache.clear()
4742

4843
prompt = default_prompt()
49-
cache.put(prompt)
44+
cache.put_prompt(prompt)
5045

51-
retrieved = cache.get(id="1")
52-
assert retrieved is prompt
53-
assert retrieved.id == "1"
54-
assert retrieved.name == "test"
55-
assert retrieved.version == 1
56-
57-
def test_put_and_get_by_name():
58-
"""Test storing and retrieving prompt by name"""
59-
cache = SharedPromptCache()
60-
cache.clear()
46+
retrieved_by_id = cache.get(id="1")
47+
assert retrieved_by_id is prompt
6148

62-
prompt = default_prompt()
63-
cache.put(prompt)
49+
retrieved_by_name_version = cache.get(name="test", version=1)
50+
assert retrieved_by_name_version is prompt
6451

65-
retrieved = cache.get(name="test")
66-
assert retrieved is prompt
67-
assert retrieved.name == "test"
68-
69-
def test_put_and_get_by_name_version():
70-
"""Test storing and retrieving prompt by name and version"""
71-
cache = SharedPromptCache()
72-
cache.clear()
73-
74-
prompt = default_prompt()
75-
cache.put(prompt)
76-
77-
retrieved = cache.get(name="test", version=1)
78-
assert retrieved is prompt
79-
assert retrieved.name == "test"
80-
assert retrieved.version == 1
81-
82-
def test_multiple_versions():
83-
"""Test handling multiple versions of the same prompt"""
84-
cache = SharedPromptCache()
85-
cache.clear()
86-
87-
prompt1 = default_prompt()
88-
prompt2 = default_prompt(id="2", version=2)
89-
90-
cache.put(prompt1)
91-
cache.put(prompt2)
92-
93-
assert cache.get(name="test", version=1) is prompt1
94-
assert cache.get(name="test", version=2) is prompt2
95-
96-
assert cache.get(name="test") is prompt2
52+
retrieved_by_name = cache.get(name="test")
53+
assert retrieved_by_name is prompt
9754

9855
def test_clear_cache():
9956
"""Test clearing the cache"""
100-
cache = SharedPromptCache()
57+
cache = SharedCache()
10158
prompt = default_prompt()
102-
cache.put(prompt)
59+
cache.put_prompt(prompt)
10360

10461
cache.clear()
105-
assert cache._prompts == {}
106-
assert cache._name_index == {}
107-
assert cache._name_version_index == {}
62+
assert cache.get_cache() == {}
10863

10964
def test_update_existing_prompt():
11065
"""Test updating an existing prompt"""
111-
cache = SharedPromptCache()
66+
cache = SharedCache()
11267
cache.clear()
11368

11469
prompt1 = default_prompt()
11570
prompt2 = default_prompt(id="1", version=2)
11671

117-
cache.put(prompt1)
118-
cache.put(prompt2)
72+
cache.put_prompt(prompt1)
73+
cache.put_prompt(prompt2)
11974

12075
retrieved = cache.get(id="1")
12176
assert retrieved is prompt2
12277
assert retrieved.version == 2
12378

124-
def test_lookup_priority():
125-
"""Test that lookup priority is id > name-version > name"""
126-
cache = SharedPromptCache()
127-
cache.clear()
128-
129-
prompt1 = default_prompt()
130-
prompt2 = default_prompt(id="2", name="test", version=2)
131-
132-
cache.put(prompt1)
133-
cache.put(prompt2)
134-
135-
assert cache.get(id="1", name="test", version=2) is prompt1
136-
137-
assert cache.get(name="test", version=2) is prompt2
138-
139-
def test_thread_safety():
140-
"""Test thread safety of the cache"""
141-
cache = SharedPromptCache()
142-
cache.clear()
143-
144-
def worker(worker_id: int):
145-
for i in range(100):
146-
prompt = default_prompt(
147-
id=f"{worker_id}-{i}",
148-
name=f"test-{worker_id}",
149-
version=i
150-
)
151-
cache.put(prompt)
152-
time.sleep(random.uniform(0, 0.001))
153-
154-
retrieved = cache.get(id=prompt.id)
155-
assert retrieved is prompt
156-
157-
threads = [Thread(target=worker, args=(i,)) for i in range(10)]
158-
159-
for t in threads:
160-
t.start()
161-
for t in threads:
162-
t.join()
163-
164-
for worker_id in range(10):
165-
for i in range(100):
166-
prompt_id = f"{worker_id}-{i}"
167-
assert cache.get(id=prompt_id) is not None
168-
16979
def test_error_handling():
17080
"""Test error handling for invalid inputs"""
171-
cache = SharedPromptCache()
81+
cache = SharedCache()
17282
cache.clear()
17383

174-
assert cache.get() is None
175-
assert cache.get(id=None, name=None, version=None) is None
84+
assert cache.get_cache() == {}
85+
assert cache.get(key="") is None
17686

17787
with pytest.raises(TypeError):
178-
cache.get(version="invalid") # type: ignore
88+
cache.get(5) # type: ignore
17989

18090
with pytest.raises(TypeError):
181-
cache.put("not a prompt") # type: ignore
91+
cache.put(5, "test") # type: ignore

0 commit comments

Comments
 (0)