
Commit 496bc53

Merge pull request #603 from adamgordonbell/agb/caching
Adding Caching to OpenAI chat
2 parents 4187419 + 9c10bd0 commit 496bc53

File tree

guidance/models/_openai.py
pyproject.toml

2 files changed: +46 -11 lines changed


guidance/models/_openai.py

Lines changed: 45 additions & 10 deletions
@@ -9,6 +9,9 @@
 import time
 import tiktoken
 import re
+import diskcache as dc
+import hashlib
+import platformdirs
 
 from ._model import Chat, Instruct
 from ._grammarless import GrammarlessEngine, Grammarless
@@ -180,8 +183,16 @@ class OpenAIChat(OpenAI, Chat):
     pass
 
 class OpenAIChatEngine(OpenAIEngine):
-    def _generator(self, prompt, temperature):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
+        self.cache = dc.Cache(path)
+
+    def _hash_prompt(self, prompt):
+        return hashlib.sha256(f"{prompt}".encode()).hexdigest()
 
+    def _generator(self, prompt, temperature):
+
         # find the role tags
         pos = 0
         role_end = b'<|im_end|>'
@@ -212,12 +223,23 @@ def _generator(self, prompt, temperature):
             raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
 Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
 to appropriately format your guidance program for this type of model.")
-
-        # update our shared data state
+
+
+        # Update shared data state
         self._reset_shared_data(prompt[:pos], temperature)
 
+        # Use cache only when temperature is 0
+        if temperature == 0:
+            cache_key = self._hash_prompt(prompt)
+
+            # Check if the result is already in the cache
+            if cache_key in self.cache:
+                for chunk in self.cache[cache_key]:
+                    yield chunk
+                return
+
+        # API call and response handling
         try:
-
             generator = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
@@ -227,11 +249,24 @@ def _generator(self, prompt, temperature):
                 temperature=temperature,
                 stream=True
             )
+
+            if temperature == 0:
+                cached_results = []
+
+            for part in generator:
+                if len(part.choices) > 0:
+                    chunk = part.choices[0].delta.content or ""
+                else:
+                    chunk = ""
+                encoded_chunk = chunk.encode("utf8")
+                yield encoded_chunk
+
+                if temperature == 0:
+                    cached_results.append(encoded_chunk)
+
+            # Cache the results after the generator is exhausted
+            if temperature == 0:
+                self.cache[cache_key] = cached_results
+
         except Exception as e: # TODO: add retry logic
             raise e
-        for part in generator:
-            if len(part.choices) > 0:
-                chunk = part.choices[0].delta.content or ""
-            else:
-                chunk = ""
-            yield chunk.encode("utf8")
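
Taken together, the change gives OpenAIChatEngine a disk-backed response cache: the constructor opens a diskcache store under the platformdirs user cache directory, _hash_prompt turns the full prompt into a SHA-256 key, and _generator either replays cached chunks on a hit (temperature 0 only) or streams from the API, collecting the encoded chunks and writing them to the cache once the stream is exhausted. Below is a minimal standalone sketch of that pattern, assuming the openai 1.x client; the cached_chat_stream helper, cache path, and model name are illustrative placeholders, not part of the guidance codebase.

import hashlib
import os

import diskcache as dc
import platformdirs
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
cache = dc.Cache(os.path.join(platformdirs.user_cache_dir("guidance-cache-demo"), "chat.tokens"))

def cached_chat_stream(messages, model="gpt-3.5-turbo", temperature=0):
    # Key the cache on the prompt text alone, mirroring _hash_prompt in the diff.
    key = hashlib.sha256(str(messages).encode()).hexdigest()

    # Only deterministic (temperature == 0) calls are served from or written to the cache.
    if temperature == 0 and key in cache:
        yield from cache[key]
        return

    collected = []
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        stream=True,
    )
    for part in stream:
        if len(part.choices) > 0:
            chunk = part.choices[0].delta.content or ""
        else:
            chunk = ""
        yield chunk
        collected.append(chunk)

    # Write only after the stream is fully consumed, so a partial response is never cached.
    if temperature == 0:
        cache[key] = collected

As in the diff, the key covers only the prompt, so a change of model or other request parameters would still map to the same cache entry; folding those into the hashed string is a possible refinement.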

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ requires = [
     "wheel",
     "pybind11>=2.10.0",
 ]
-build-backend = "setuptools.build_meta"
+build-backend = "setuptools.build_meta"
