11import pytest
2- from threading import Thread
3- import time
4- import random
52
63from literalai .prompt_engineering .prompt import Prompt
7- from literalai .api import SharedPromptCache , LiteralAPI
4+ from literalai .api import SharedCache , LiteralAPI
85
96def default_prompt (id : str = "1" , name : str = "test" , version : int = 1 ) -> Prompt :
107 return Prompt (
@@ -26,156 +23,69 @@ def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Promp
2623 )
2724
2825def test_singleton_instance ():
29- """Test that SharedPromptCache maintains singleton pattern"""
30- cache1 = SharedPromptCache ()
31- cache2 = SharedPromptCache ()
26+ """Test that SharedCache maintains singleton pattern"""
27+ cache1 = SharedCache ()
28+ cache2 = SharedCache ()
3229 assert cache1 is cache2
3330
3431def test_get_empty_cache ():
3532 """Test getting from empty cache returns None"""
36- cache = SharedPromptCache ()
33+ cache = SharedCache ()
3734 cache .clear ()
3835
39- assert cache ._prompts == {}
40- assert cache ._name_index == {}
41- assert cache ._name_version_index == {}
36+ assert cache .get_cache () == {}
4237
43- def test_put_and_get_by_id ():
44- """Test storing and retrieving prompt by ID"""
45- cache = SharedPromptCache ()
38+ def test_put_and_get_prompt_by_id_by_name_version_by_name ():
39+ """Test storing and retrieving prompt by ID by name-version by name """
40+ cache = SharedCache ()
4641 cache .clear ()
4742
4843 prompt = default_prompt ()
49- cache .put (prompt )
44+ cache .put_prompt (prompt )
5045
51- retrieved = cache .get (id = "1" )
52- assert retrieved is prompt
53- assert retrieved .id == "1"
54- assert retrieved .name == "test"
55- assert retrieved .version == 1
56-
57- def test_put_and_get_by_name ():
58- """Test storing and retrieving prompt by name"""
59- cache = SharedPromptCache ()
60- cache .clear ()
46+ retrieved_by_id = cache .get (id = "1" )
47+ assert retrieved_by_id is prompt
6148
62- prompt = default_prompt ( )
63- cache . put ( prompt )
49+ retrieved_by_name_version = cache . get ( name = "test" , version = 1 )
50+ assert retrieved_by_name_version is prompt
6451
65- retrieved = cache .get (name = "test" )
66- assert retrieved is prompt
67- assert retrieved .name == "test"
68-
69- def test_put_and_get_by_name_version ():
70- """Test storing and retrieving prompt by name and version"""
71- cache = SharedPromptCache ()
72- cache .clear ()
73-
74- prompt = default_prompt ()
75- cache .put (prompt )
76-
77- retrieved = cache .get (name = "test" , version = 1 )
78- assert retrieved is prompt
79- assert retrieved .name == "test"
80- assert retrieved .version == 1
81-
82- def test_multiple_versions ():
83- """Test handling multiple versions of the same prompt"""
84- cache = SharedPromptCache ()
85- cache .clear ()
86-
87- prompt1 = default_prompt ()
88- prompt2 = default_prompt (id = "2" , version = 2 )
89-
90- cache .put (prompt1 )
91- cache .put (prompt2 )
92-
93- assert cache .get (name = "test" , version = 1 ) is prompt1
94- assert cache .get (name = "test" , version = 2 ) is prompt2
95-
96- assert cache .get (name = "test" ) is prompt2
52+ retrieved_by_name = cache .get (name = "test" )
53+ assert retrieved_by_name is prompt
9754
9855def test_clear_cache ():
9956 """Test clearing the cache"""
100- cache = SharedPromptCache ()
57+ cache = SharedCache ()
10158 prompt = default_prompt ()
102- cache .put (prompt )
59+ cache .put_prompt (prompt )
10360
10461 cache .clear ()
105- assert cache ._prompts == {}
106- assert cache ._name_index == {}
107- assert cache ._name_version_index == {}
62+ assert cache .get_cache () == {}
10863
10964def test_update_existing_prompt ():
11065 """Test updating an existing prompt"""
111- cache = SharedPromptCache ()
66+ cache = SharedCache ()
11267 cache .clear ()
11368
11469 prompt1 = default_prompt ()
11570 prompt2 = default_prompt (id = "1" , version = 2 )
11671
117- cache .put (prompt1 )
118- cache .put (prompt2 )
72+ cache .put_prompt (prompt1 )
73+ cache .put_prompt (prompt2 )
11974
12075 retrieved = cache .get (id = "1" )
12176 assert retrieved is prompt2
12277 assert retrieved .version == 2
12378
124- def test_lookup_priority ():
125- """Test that lookup priority is id > name-version > name"""
126- cache = SharedPromptCache ()
127- cache .clear ()
128-
129- prompt1 = default_prompt ()
130- prompt2 = default_prompt (id = "2" , name = "test" , version = 2 )
131-
132- cache .put (prompt1 )
133- cache .put (prompt2 )
134-
135- assert cache .get (id = "1" , name = "test" , version = 2 ) is prompt1
136-
137- assert cache .get (name = "test" , version = 2 ) is prompt2
138-
139- def test_thread_safety ():
140- """Test thread safety of the cache"""
141- cache = SharedPromptCache ()
142- cache .clear ()
143-
144- def worker (worker_id : int ):
145- for i in range (100 ):
146- prompt = default_prompt (
147- id = f"{ worker_id } -{ i } " ,
148- name = f"test-{ worker_id } " ,
149- version = i
150- )
151- cache .put (prompt )
152- time .sleep (random .uniform (0 , 0.001 ))
153-
154- retrieved = cache .get (id = prompt .id )
155- assert retrieved is prompt
156-
157- threads = [Thread (target = worker , args = (i ,)) for i in range (10 )]
158-
159- for t in threads :
160- t .start ()
161- for t in threads :
162- t .join ()
163-
164- for worker_id in range (10 ):
165- for i in range (100 ):
166- prompt_id = f"{ worker_id } -{ i } "
167- assert cache .get (id = prompt_id ) is not None
168-
16979def test_error_handling ():
17080 """Test error handling for invalid inputs"""
171- cache = SharedPromptCache ()
81+ cache = SharedCache ()
17282 cache .clear ()
17383
174- assert cache .get () is None
175- assert cache .get (id = None , name = None , version = None ) is None
84+ assert cache .get_cache () == {}
85+ assert cache .get (key = "" ) is None
17686
17787 with pytest .raises (TypeError ):
178- cache .get (version = "invalid" ) # type: ignore
88+ cache .get (5 ) # type: ignore
17989
18090 with pytest .raises (TypeError ):
181- cache .put ("not a prompt " ) # type: ignore
91+ cache .put (5 , "test " ) # type: ignore
0 commit comments