This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 5318751 (1 parent: 3730581)

refactor: finishes the simplification

File tree: 2 files changed (+32, -59 lines)

literalai/api/__init__.py

Lines changed: 20 additions & 56 deletions
@@ -1,6 +1,5 @@
 import logging
 import os
-from threading import Lock
 import uuid
 from typing import (
     Any,
@@ -142,77 +141,40 @@ def handle_bytes(item):
     return handle_bytes(variables)
 
 
-class SharedPromptCache:
+class SharedCache:
     """
     Thread-safe singleton cache for storing prompts with memory leak prevention.
     Only one instance will exist regardless of how many times it's instantiated.
     """
     _instance = None
-    _lock = Lock()
-    _prompts: dict[str, Prompt]
-    _name_index: dict[str, str]
-    _name_version_index: dict[tuple[str, int], str]
+    _cache: dict[str, Any]
 
     def __new__(cls):
-        with cls._lock:
-            if cls._instance is None:
-                cls._instance = super().__new__(cls)
-
-                cls._instance._prompts = {}
-                cls._instance._name_index = {}
-                cls._instance._name_version_index = {}
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._cache = {}
         return cls._instance
 
-    def get(
-        self,
-        id: Optional[str] = None,
-        name: Optional[str] = None,
-        version: Optional[int] = None
-    ) -> Optional[Prompt]:
+    def get_cache(self) -> dict[str, Any]:
+        return self._cache
+
+    def get(self, key: str) -> Optional[Any]:
         """
-        Retrieves a prompt using the most specific criteria provided.
-        Lookup priority: id, name-version, name
+        Retrieves a value from the cache using the provided key.
         """
-        if id and not isinstance(id, str):
-            raise TypeError("Expected a string for id")
-        if name and not isinstance(name, str):
-            raise TypeError("Expected a string for name")
-        if version and not isinstance(version, int):
-            raise TypeError("Expected an integer for version")
+        return self._cache.get(key)
 
-        if id:
-            prompt_id = id
-        elif name and version:
-            prompt_id = self._name_version_index.get((name, version)) or ""
-        elif name:
-            prompt_id = self._name_index.get(name) or ""
-        else:
-            return None
-
-        if prompt_id and prompt_id in self._prompts:
-            return self._prompts.get(prompt_id)
-        return None
-
-    def put(self, prompt: Prompt):
+    def put(self, key: str, value: Any):
         """
-        Stores a prompt in the cache.
+        Stores a value in the cache.
         """
-        if not isinstance(prompt, Prompt):
-            raise TypeError("Expected a Prompt object")
-
-        with self._lock:
-            self._prompts[prompt.id] = prompt
-            self._name_index[prompt.name] = prompt.id
-            self._name_version_index[(prompt.name, prompt.version)] = prompt.id
+        self._cache[key] = value
 
     def clear(self) -> None:
         """
-        Clears all cached prompts and indices.
+        Clears all cached values.
        """
-        with self._lock:
-            self._prompts.clear()
-            self._name_index.clear()
-            self._name_version_index.clear()
+        self._cache.clear()
 
 
 class BaseLiteralAPI:
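
The net effect of this hunk: the cache loses its lock and its three purpose-built indices and becomes a plain keyed dictionary behind a singleton. Below is a minimal, self-contained sketch of the resulting behavior; the class body mirrors the diff, while the demo keys at the bottom are made up:

```python
from typing import Any, Optional


class SharedCache:
    """Minimal sketch of the simplified singleton cache from this commit."""

    _instance = None
    _cache: dict[str, Any]

    def __new__(cls):
        # Lazily create the single shared instance. The old lock is gone,
        # so instance creation itself is no longer synchronized.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._cache = {}
        return cls._instance

    def get(self, key: str) -> Optional[Any]:
        return self._cache.get(key)

    def put(self, key: str, value: Any) -> None:
        self._cache[key] = value


# Every instantiation returns the same object backed by the same dict:
a, b = SharedCache(), SharedCache()
a.put("greeting", "hello")
assert a is b and b.get("greeting") == "hello"
```

Note that with `_lock` removed, the docstring's "thread-safe" claim now rests only on CPython's GIL making individual dict reads and writes atomic; concurrent first-time instantiation is a (mostly harmless) race.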
@@ -239,7 +201,7 @@ def __init__(
         self.graphql_endpoint = self.url + "/api/graphql"
         self.rest_endpoint = self.url + "/api"
 
-        self.prompt_cache = SharedPromptCache()
+        self.cache = SharedCache()
 
     @property
     def headers(self):
@@ -1445,7 +1407,9 @@ def get_prompt(
             elif name:
                 prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
 
-            self.prompt_cache.put(prompt)
+            self.cache.put(prompt.id, prompt)
+            self.cache.put(prompt.name, prompt)
+            self.cache.put(f"{prompt.name}-{prompt.version}", prompt)
             return prompt
 
         except Exception as e:
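
With the keyed API, `get_prompt` now registers the same `Prompt` under three alternative lookup keys instead of letting the cache maintain internal indices. A hypothetical stand-alone sketch of that write pattern, using a plain dict in place of `SharedCache` and a `Prompt` stub that carries only the fields the keys use:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Prompt:
    # Stub for illustration; the real literalai Prompt carries much more.
    id: str
    name: str
    version: int


def cache_prompt(cache: dict[str, Any], prompt: Prompt) -> None:
    # One object, three alternative lookup keys, mirroring the diff above.
    cache[prompt.id] = prompt
    cache[prompt.name] = prompt
    cache[f"{prompt.name}-{prompt.version}"] = prompt


cache: dict[str, Any] = {}
cache_prompt(cache, Prompt(id="prompt-123", name="welcome", version=2))
assert cache["prompt-123"] is cache["welcome"] is cache["welcome-2"]
```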

literalai/api/prompt_helpers.py

Lines changed: 12 additions & 3 deletions
@@ -57,6 +57,17 @@ def process_response(response):
     return gql.CREATE_PROMPT_VERSION, description, variables, process_response
 
 
+def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
+    if id:
+        return id
+    elif name and version:
+        return f"{name}-{version}"
+    elif name:
+        return name
+    else:
+        raise ValueError("Either the `id` or the `name` must be provided.")
+
+
 def get_prompt_helper(
     api: "LiteralAPI",
     id: Optional[str] = None,
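
This new helper makes the lookup priority that used to live inside `SharedPromptCache.get` explicit: `id` first, then `name-version`, then bare `name`. Copied out of the diff so the illustrative examples below run standalone:

```python
from typing import Optional


def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
    if id:
        return id
    elif name and version:
        return f"{name}-{version}"
    elif name:
        return name
    else:
        raise ValueError("Either the `id` or the `name` must be provided.")


# Most specific key wins (example values are made up):
assert get_prompt_cache_key("prompt-123", None, None) == "prompt-123"
assert get_prompt_cache_key(None, "welcome", 2) == "welcome-2"
assert get_prompt_cache_key(None, "welcome", None) == "welcome"
# Caveat: truthiness checks mean version=0 would fall through to the
# bare name key rather than "welcome-0".
```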
@@ -65,14 +76,12 @@ def get_prompt_helper(
     prompt_cache: Optional["SharedPromptCache"] = None,
 ) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
     """Helper function for getting prompts with caching logic"""
-    if not (id or name):
-        raise ValueError("Either the `id` or the `name` must be provided.")
 
     cached_prompt = None
     timeout = 10
 
     if prompt_cache:
-        cached_prompt = prompt_cache.get(id, name, version)
+        cached_prompt = prompt_cache.get(get_prompt_cache_key(id, name, version))
         timeout = 1 if cached_prompt else timeout
 
     variables = {"id": id, "name": name, "version": version}
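
One behavior this hunk preserves is the timeout trade-off: when the cache already holds a prompt, the helper drops the GraphQL timeout from 10 s to 1 s, presumably so that a slow or unreachable server quickly gives way to the cached copy. A hypothetical distillation of that logic:

```python
from typing import Any, Optional


def resolve_timeout(cached_prompt: Optional[Any], default: int = 10) -> int:
    # With a cache hit we can afford an aggressive 1 s network timeout,
    # since the caller can fall back to cached_prompt on failure.
    return 1 if cached_prompt else default


assert resolve_timeout(None) == 10
assert resolve_timeout("some cached prompt") == 1
```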
