diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py
index bab92e616..3fa4d4e6a 100644
--- a/guidance/models/_openai.py
+++ b/guidance/models/_openai.py
@@ -9,6 +9,9 @@
 import time
 import tiktoken
 import re
+import diskcache as dc
+import hashlib
+import platformdirs
 
 from ._model import Chat, Instruct
 from ._grammarless import GrammarlessEngine, Grammarless
@@ -180,8 +183,16 @@ class OpenAIChat(OpenAI, Chat):
     pass
 
 class OpenAIChatEngine(OpenAIEngine):
-    def _generator(self, prompt, temperature):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
+        self.cache = dc.Cache(path)
+
+    def _hash_prompt(self, prompt):
+        return hashlib.sha256(f"{prompt}".encode()).hexdigest()
+
+    def _generator(self, prompt, temperature):
 
         # find the role tags
         pos = 0
         role_end = b'<|im_end|>'
@@ -212,12 +223,23 @@ def _generator(self, prompt, temperature):
             raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
 Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
 to appropriately format your guidance program for this type of model.")
-
-        # update our shared data state
+
+
+        # Update shared data state
         self._reset_shared_data(prompt[:pos], temperature)
 
+        # Use cache only when temperature is 0
+        if temperature == 0:
+            cache_key = self._hash_prompt(prompt)
+
+            # Check if the result is already in the cache
+            if cache_key in self.cache:
+                for chunk in self.cache[cache_key]:
+                    yield chunk
+                return
+
+        # API call and response handling
         try:
-
             generator = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
@@ -227,11 +249,24 @@ def _generator(self, prompt, temperature):
                 temperature=temperature,
                 stream=True
             )
+
+            if temperature == 0:
+                cached_results = []
+
+            for part in generator:
+                if len(part.choices) > 0:
+                    chunk = part.choices[0].delta.content or ""
+                else:
+                    chunk = ""
+                encoded_chunk = chunk.encode("utf8")
+                yield encoded_chunk
+
+                if temperature == 0:
+                    cached_results.append(encoded_chunk)
+
+            # Cache the results after the generator is exhausted
+            if temperature == 0:
+                self.cache[cache_key] = cached_results
+
         except Exception as e:
             # TODO: add retry logic
             raise e
-        for part in generator:
-            if len(part.choices) > 0:
-                chunk = part.choices[0].delta.content or ""
-            else:
-                chunk = ""
-            yield chunk.encode("utf8")
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index c81285adf..8c55f4ed5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,4 +4,4 @@ requires = [
     "wheel",
     "pybind11>=2.10.0",
 ]
-build-backend = "setuptools.build_meta"
\ No newline at end of file
+build-backend = "setuptools.build_meta"
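
Note (not part of the patch): the caching path added above only engages when temperature == 0, keys the diskcache entry on a SHA-256 hash of the full prompt, and writes the entry only after the OpenAI stream has been fully consumed. Below is a minimal standalone sketch of that pattern; fake_stream(), generate(), and the temporary cache directory are illustrative stand-ins, not guidance or OpenAI APIs.

import hashlib
import tempfile

import diskcache as dc

# The patch stores the cache under platformdirs.user_cache_dir("guidance");
# a throwaway temp directory is used here for illustration.
cache = dc.Cache(tempfile.mkdtemp())

def _hash_prompt(prompt):
    # Same keying scheme as the patch: SHA-256 over the prompt text.
    return hashlib.sha256(f"{prompt}".encode()).hexdigest()

def fake_stream():
    # Stand-in for the streamed chat-completion chunks.
    yield from (b"Hello", b", ", b"world")

def generate(prompt, temperature):
    if temperature == 0:
        cache_key = _hash_prompt(prompt)
        if cache_key in cache:
            # Replay the cached chunks and skip the "API call" entirely.
            yield from cache[cache_key]
            return

    cached_results = []
    for chunk in fake_stream():
        yield chunk
        if temperature == 0:
            cached_results.append(chunk)

    # Write the cache entry only after the stream is exhausted, so a
    # partially consumed generator never stores a truncated result.
    if temperature == 0:
        cache[cache_key] = cached_results

print(b"".join(generate("hi", 0)))  # first call: streams and populates the cache
print(b"".join(generate("hi", 0)))  # second call: served from the disk cache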