@@ -9,6 +9,9 @@
 import time
 import tiktoken
 import re
+import diskcache as dc
+import hashlib
+import platformdirs
 
 from ._model import Chat, Instruct
 from ._remote import Remote
@@ -158,9 +161,18 @@ def _generator(self, prompt, temperature):
                 chunk = ""
             yield chunk.encode("utf8")
 
+
 class OAIChatMixin(Chat):
-    def _generator(self, prompt, temperature):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
+        self.cache = dc.Cache(path)
+
+    def _hash_prompt(self, prompt):
+        return hashlib.sha256(f"{prompt}".encode()).hexdigest()
 
+    def _generator(self, prompt, temperature):
+
         # find the role tags
         pos = 0
         role_end = b'<|im_end|>'
@@ -191,12 +203,24 @@ def _generator(self, prompt, temperature):
             raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
     Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
     to appropriately format your guidance program for this type of model.")
-
-        # update our shared data state
+
+
+        # Update shared data state
         self._reset_shared_data(prompt[:pos], temperature)
 
+        # Use cache only when temperature is 0
+        if temperature == 0:
+            cache_key = self._hash_prompt(prompt)
+
+            # Check if the result is already in the cache
+            if cache_key in self.cache:
+                print("cache hit")
+                for chunk in self.cache[cache_key]:
+                    yield chunk
+                return
+
+        # API call and response handling
         try:
-
             generator = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
@@ -206,14 +230,27 @@ def _generator(self, prompt, temperature):
                 temperature=temperature,
                 stream=True
             )
-        except Exception as e:  # TODO: add retry logic
+
+            if temperature == 0:
+                cached_results = []
+
+            for part in generator:
+                if len(part.choices) > 0:
+                    chunk = part.choices[0].delta.content or ""
+                else:
+                    chunk = ""
+                encoded_chunk = chunk.encode("utf8")
+                yield encoded_chunk
+
+                if temperature == 0:
+                    cached_results.append(encoded_chunk)
+
+            # Cache the results after the generator is exhausted
+            if temperature == 0:
+                self.cache[cache_key] = cached_results
+
+        except Exception as e:
             raise e
-        for part in generator:
-            if len(part.choices) > 0:
-                chunk = part.choices[0].delta.content or ""
-            else:
-                chunk = ""
-            yield chunk.encode("utf8")
 
 class OpenAICompletion(OpenAI, OAICompletionMixin):
     def __init__(self, *args, **kwargs):
@@ -225,4 +262,4 @@ def __init__(self, *args, **kwargs):
 
 class OpenAIChat(OpenAI, OAIChatMixin):
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
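Taken together, the diff keys deterministic (temperature == 0) requests by a SHA-256 hash of the prompt, replays the stored chunks on a cache hit, and otherwise accumulates the streamed chunks so they can be written to a diskcache.Cache once the stream is exhausted. The sketch below restates that pattern in isolation; fake_stream is a hypothetical stand-in for the self.client.chat.completions.create(..., stream=True) call, and the rest mirrors the code above.

import hashlib
import os

import diskcache as dc
import platformdirs

# Same on-disk location the diff uses for the cache.
cache = dc.Cache(os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens"))

def hash_prompt(prompt):
    # SHA-256 of the prompt text serves as a stable cache key.
    return hashlib.sha256(f"{prompt}".encode()).hexdigest()

def fake_stream(prompt):
    # Hypothetical stand-in for the streaming OpenAI API call.
    for word in ("Hello", ", ", "world"):
        yield word.encode("utf8")

def cached_stream(prompt, temperature):
    # Only temperature == 0 responses are deterministic enough to replay.
    if temperature == 0:
        key = hash_prompt(prompt)
        if key in cache:
            yield from cache[key]
            return
    chunks = []
    for chunk in fake_stream(prompt):
        yield chunk
        if temperature == 0:
            chunks.append(chunk)
    # Store the full chunk list only after the stream is exhausted.
    if temperature == 0:
        cache[key] = chunks

first = b"".join(cached_stream("Say hi", temperature=0))   # populates the cache
second = b"".join(cached_stream("Say hi", temperature=0))  # served from disk
assert first == second

Restricting the cache to temperature 0 is deliberate: at higher temperatures completions are sampled, so replaying a stored response would silently change the model's output distribution.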