
Commit f075f80

Adding caching to openai chat
1 parent d36601b

2 files changed: +80 additions, -13 deletions


guidance/models/_openai.py

Lines changed: 49 additions & 12 deletions
@@ -9,6 +9,9 @@
 import time
 import tiktoken
 import re
+import diskcache as dc
+import hashlib
+import platformdirs
 
 from ._model import Chat, Instruct
 from ._remote import Remote
@@ -158,9 +161,18 @@ def _generator(self, prompt, temperature):
                 chunk = ""
             yield chunk.encode("utf8")
 
+
 class OAIChatMixin(Chat):
-    def _generator(self, prompt, temperature):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
+        self.cache = dc.Cache(path)
+
+    def _hash_prompt(self, prompt):
+        return hashlib.sha256(f"{prompt}".encode()).hexdigest()
 
+    def _generator(self, prompt, temperature):
+
         # find the role tags
         pos = 0
         role_end = b'<|im_end|>'
@@ -191,12 +203,24 @@ def _generator(self, prompt, temperature):
             raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
 Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
 to appropriately format your guidance program for this type of model.")
-
-        # update our shared data state
+
+
+        # Update shared data state
         self._reset_shared_data(prompt[:pos], temperature)
 
+        # Use cache only when temperature is 0
+        if temperature == 0:
+            cache_key = self._hash_prompt(prompt)
+
+            # Check if the result is already in the cache
+            if cache_key in self.cache:
+                print("cache hit")
+                for chunk in self.cache[cache_key]:
+                    yield chunk
+                return
+
+        # API call and response handling
         try:
-
             generator = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
@@ -206,14 +230,27 @@ def _generator(self, prompt, temperature):
                 temperature=temperature,
                 stream=True
             )
-        except Exception as e: # TODO: add retry logic
+
+            if temperature == 0:
+                cached_results = []
+
+            for part in generator:
+                if len(part.choices) > 0:
+                    chunk = part.choices[0].delta.content or ""
+                else:
+                    chunk = ""
+                encoded_chunk = chunk.encode("utf8")
+                yield encoded_chunk
+
+                if temperature == 0:
+                    cached_results.append(encoded_chunk)
+
+            # Cache the results after the generator is exhausted
+            if temperature == 0:
+                self.cache[cache_key] = cached_results
+
+        except Exception as e:
             raise e
-        for part in generator:
-            if len(part.choices) > 0:
-                chunk = part.choices[0].delta.content or ""
-            else:
-                chunk = ""
-            yield chunk.encode("utf8")
 
 class OpenAICompletion(OpenAI, OAICompletionMixin):
     def __init__(self, *args, **kwargs):
@@ -225,4 +262,4 @@ def __init__(self, *args, **kwargs):
 
 class OpenAIChat(OpenAI, OAIChatMixin):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
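The hunks above wrap guidance's streaming chat generator in a disk-backed cache: responses are keyed by a SHA-256 hash of the prompt, recorded chunk by chunk as they stream, and replayed on the next identical call. Caching is gated on temperature == 0, the only setting where replaying an old completion is a reasonable stand-in for a fresh one. Below is a minimal self-contained sketch of the same pattern; `fake_stream` and the "guidance-demo" cache directory are illustrative stand-ins, not part of the commit.

import hashlib
import os

import diskcache as dc
import platformdirs

# Same cache construction as the commit: a diskcache.Cache in the
# platform-appropriate user cache directory.
cache = dc.Cache(os.path.join(platformdirs.user_cache_dir("guidance-demo"), "chat.tokens"))

def fake_stream(prompt):
    # Stand-in for client.chat.completions.create(..., stream=True).
    for word in prompt.split():
        yield (word + " ").encode("utf8")

def cached_stream(prompt, temperature=0):
    key = hashlib.sha256(prompt.encode()).hexdigest()
    # Replay a previously recorded stream on an exact prompt match.
    if temperature == 0 and key in cache:
        yield from cache[key]
        return
    chunks = []
    for chunk in fake_stream(prompt):
        yield chunk            # stream to the caller immediately...
        chunks.append(chunk)   # ...while recording for later replay
    if temperature == 0:
        cache[key] = chunks    # persist only after the stream is exhausted

print(b"".join(cached_stream("hello streaming cache")))  # cache miss: runs the generator
print(b"".join(cached_stream("hello streaming cache")))  # cache hit: replayed from disk

One design consequence visible in both the commit and the sketch: the key hashes only the prompt, so the model name and other request parameters are not part of the cache key, and two different models given the same prompt would share an entry.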

pyproject.toml

Lines changed: 31 additions & 1 deletion
@@ -4,4 +4,34 @@ requires = [
     "wheel",
     "pybind11>=2.10.0",
 ]
-build-backend = "setuptools.build_meta"
+build-backend = "setuptools.build_meta"
+
+
+[tool.poetry]
+name = "guidance"
+homepage = "https://github.com/guidance-ai/guidance"
+
+[tool.poetry.dependencies]
+python = ">=3.8"
+diskcache = "*"
+gptcache = "*"
+openai = ">=1.0"
+platformdirs = "*"
+tiktoken = ">=0.3"
+msal = "*"
+requests = "*"
+numpy = "*"
+aiohttp = "*"
+ordered-set = "*"
+pyformlang = "*"
+
+[tool.poetry.dev-dependencies]
+ipython = "*"
+numpydoc = "*"
+sphinx_rtd_theme = "*"
+sphinx = "*"
+nbsphinx = "*"
+pytest = "*"
+transformers = "*"
+torch = "*"
+pytest-cov = "*"
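The new poetry tables sit alongside the existing setuptools build backend and mainly enumerate dependencies, including the diskcache and platformdirs packages the caching code now imports. A small sanity check, assuming the packages are installed locally (module names are taken from the diff; not every package exposes `__version__`, hence the getattr fallback):

import diskcache
import platformdirs
import tiktoken

# Print whichever version metadata each module exposes.
for mod in (diskcache, platformdirs, tiktoken):
    print(mod.__name__, getattr(mod, "__version__", "unknown"))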
