import time
import tiktoken
import re
+ import diskcache as dc
+ import hashlib
+ import platformdirs

from ._model import Chat, Instruct
from ._grammarless import GrammarlessEngine, Grammarless
@@ -180,8 +183,16 @@ class OpenAIChat(OpenAI, Chat):
    pass

class OpenAIChatEngine(OpenAIEngine):
-    def _generator(self, prompt, temperature):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
+        self.cache = dc.Cache(path)
+
+    def _hash_prompt(self, prompt):
+        return hashlib.sha256(f"{prompt}".encode()).hexdigest()

+    def _generator(self, prompt, temperature):
+
        # find the role tags
        pos = 0
        role_end = b'<|im_end|>'
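The cache wiring added in this hunk can be exercised on its own. Below is a minimal, self-contained sketch (not part of this patch) that assumes the PyPI `diskcache` and `platformdirs` packages and uses a hypothetical standalone `hash_prompt` helper in place of the `_hash_prompt` method:

```python
# Sketch only: platformdirs picks a per-user cache directory and diskcache
# persists pickled values on disk across processes.
import hashlib
import os

import diskcache as dc
import platformdirs

path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
cache = dc.Cache(path)

def hash_prompt(prompt):
    # Stable cache key: hex SHA-256 digest of the prompt text.
    return hashlib.sha256(f"{prompt}".encode()).hexdigest()

key = hash_prompt("<|im_start|>user\nHello<|im_end|>")
cache[key] = [b"Hi", b" there!"]          # store a list of byte chunks
assert cache[key] == [b"Hi", b" there!"]  # still readable after a restart
```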
@@ -212,12 +223,23 @@ def _generator(self, prompt, temperature):
            raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
                Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
                to appropriately format your guidance program for this type of model.")
-
-        # update our shared data state
+
+
+        # Update shared data state
        self._reset_shared_data(prompt[:pos], temperature)

+        # Use cache only when temperature is 0
+        if temperature == 0:
+            cache_key = self._hash_prompt(prompt)
+
+            # Check if the result is already in the cache
+            if cache_key in self.cache:
+                for chunk in self.cache[cache_key]:
+                    yield chunk
+                return
+
+        # API call and response handling
        try:
-
            generator = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
@@ -227,11 +249,24 @@ def _generator(self, prompt, temperature):
                temperature=temperature,
                stream=True
            )
+
+            if temperature == 0:
+                cached_results = []
+
+            for part in generator:
+                if len(part.choices) > 0:
+                    chunk = part.choices[0].delta.content or ""
+                else:
+                    chunk = ""
+                encoded_chunk = chunk.encode("utf8")
+                yield encoded_chunk
+
+                if temperature == 0:
+                    cached_results.append(encoded_chunk)
+
+            # Cache the results after the generator is exhausted
+            if temperature == 0:
+                self.cache[cache_key] = cached_results
+
        except Exception as e: # TODO: add retry logic
            raise e
-        for part in generator:
-            if len(part.choices) > 0:
-                chunk = part.choices[0].delta.content or ""
-            else:
-                chunk = ""
-            yield chunk.encode("utf8")
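Taken together, the new `_generator` body reduces to a cache-or-stream pattern: replay recorded chunks for deterministic (temperature == 0) requests, otherwise stream live and record. A hedged paraphrase follows; the names `cached_stream` and `make_stream` are illustrative and not from this PR, and the real method also handles role-tag parsing and the error handling shown above:

```python
def cached_stream(cache, cache_key, make_stream, temperature):
    # Deterministic requests (temperature == 0) can be replayed from the cache.
    if temperature == 0 and cache_key in cache:
        yield from cache[cache_key]
        return

    # Otherwise stream live chunks, recording them as they are yielded.
    recorded = []
    for chunk in make_stream():
        yield chunk
        if temperature == 0:
            recorded.append(chunk)

    # Only write the cache entry once the stream has completed.
    if temperature == 0:
        cache[cache_key] = recorded
```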