55 changes: 45 additions & 10 deletions guidance/models/_openai.py
@@ -9,6 +9,9 @@
import time
import tiktoken
import re
import diskcache as dc
import hashlib
import platformdirs

from ._model import Chat, Instruct
from ._grammarless import GrammarlessEngine, Grammarless
@@ -180,8 +183,16 @@ class OpenAIChat(OpenAI, Chat):
    pass

class OpenAIChatEngine(OpenAIEngine):
    def _generator(self, prompt, temperature):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
        self.cache = dc.Cache(path)

    def _hash_prompt(self, prompt):
        return hashlib.sha256(f"{prompt}".encode()).hexdigest()

    def _generator(self, prompt, temperature):

        # find the role tags
        pos = 0
        role_end = b'<|im_end|>'
@@ -212,12 +223,23 @@ def _generator(self, prompt, temperature):
            raise ValueError(f"The OpenAI model {self.model_name} is a Chat-based model and requires role tags in the prompt! \
Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
to appropriately format your guidance program for this type of model.")

        # update our shared data state


        # Update shared data state
        self._reset_shared_data(prompt[:pos], temperature)

        # Use cache only when temperature is 0
        if temperature == 0:
            cache_key = self._hash_prompt(prompt)

            # Check if the result is already in the cache
            if cache_key in self.cache:
                for chunk in self.cache[cache_key]:
                    yield chunk
                return

        # API call and response handling
        try:

            generator = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
@@ -227,11 +249,24 @@ def _generator(self, prompt, temperature):
                temperature=temperature,
                stream=True
            )

            if temperature == 0:
                cached_results = []

            for part in generator:
                if len(part.choices) > 0:
                    chunk = part.choices[0].delta.content or ""
                else:
                    chunk = ""
                encoded_chunk = chunk.encode("utf8")
                yield encoded_chunk

                if temperature == 0:
                    cached_results.append(encoded_chunk)

            # Cache the results after the generator is exhausted
            if temperature == 0:
                self.cache[cache_key] = cached_results

        except Exception as e: # TODO: add retry logic
            raise e
        for part in generator:
            if len(part.choices) > 0:
                chunk = part.choices[0].delta.content or ""
            else:
                chunk = ""
            yield chunk.encode("utf8")
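
To see the caching behavior this diff adds in isolation, here is a minimal standalone sketch of the pattern (not the engine code itself): temperature-0 requests are keyed by a SHA-256 hash of the prompt, a cache hit replays the stored chunks, and a miss streams chunks while accumulating them so they can be written back once the stream finishes. The `fake_stream` generator and the `generate` wrapper are hypothetical stand-ins for the OpenAI streaming call and `_generator`; only the `diskcache`, `hashlib`, and `platformdirs` usage mirrors the diff.

```python
import hashlib
import os

import diskcache as dc
import platformdirs

# Same on-disk location the PR uses for OpenAIChatEngine.
cache_path = os.path.join(platformdirs.user_cache_dir("guidance"), "openai.tokens")
cache = dc.Cache(cache_path)


def hash_prompt(prompt: str) -> str:
    # Mirrors _hash_prompt: a SHA-256 digest of the prompt text as the cache key.
    return hashlib.sha256(f"{prompt}".encode()).hexdigest()


def fake_stream():
    # Hypothetical stand-in for client.chat.completions.create(..., stream=True).
    for piece in ("Hello", ", ", "world", "!"):
        yield piece


def generate(prompt: str, temperature: float = 0.0):
    key = hash_prompt(prompt)

    # Replay cached chunks when the request is deterministic and already cached.
    if temperature == 0 and key in cache:
        yield from cache[key]
        return

    cached_results = []
    for chunk in fake_stream():
        encoded = chunk.encode("utf8")
        yield encoded
        if temperature == 0:
            cached_results.append(encoded)

    # Write back only after the stream is exhausted, so a partial
    # response is never stored.
    if temperature == 0:
        cache[key] = cached_results


if __name__ == "__main__":
    print(b"".join(generate("hi")))  # first call streams and populates the cache
    print(b"".join(generate("hi")))  # second call is served from the disk cache
```

Because chunks are appended as they are yielded and the entry is written only after the loop completes, an interrupted stream never leaves a truncated response in the cache.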
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,4 +4,4 @@ requires = [
"wheel",
"pybind11>=2.10.0",
]
build-backend = "setuptools.build_meta"
build-backend = "setuptools.build_meta"