Skip to content

Commit 58d2316

Browse files
authored
OpenAI v1 support (#441)
* Support both openai v0 and v1 * Adjust CI for openai v0/v1 * Move LLMResponse object to its own file * Add usage to provenance test mock * Comment out ArbitraryCallable test suite for openai errors * Adjust tests for openai v0/v1 * remove unused imports * OpenAIv1: Don't instantiate openai.completions.create if key isn't present in environ
1 parent 5f03aae commit 58d2316

33 files changed

+805
-278
lines changed

.github/workflows/ci.yml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jobs:
5151
matrix:
5252
python-version: ['3.8', '3.9', '3.10', '3.11']
5353
pydantic-version: ['1.10.9', '2.4.2']
54+
openai-version: ['0.28.1', '1.2.4']
5455
steps:
5556
- uses: actions/checkout@v2
5657
- name: Set up Python ${{ matrix.python-version }}
@@ -72,16 +73,27 @@ jobs:
7273
run: |
7374
make full
7475
poetry run pip install pydantic==${{ matrix.pydantic-version }}
76+
poetry run pip install openai==${{ matrix.openai-version }}
7577
76-
- if: matrix.pydantic-version == '2.4.2'
77-
name: Static analysis with pyright (ignoring pydantic v1)
78+
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '0.28.1'
79+
name: Static analysis with pyright (ignoring pydantic v1 and openai v1)
7880
run: |
79-
make type-pydantic-v2
81+
make type-pydantic-v2-openai-v0
8082
81-
- if: matrix.pydantic-version == '1.10.9'
82-
name: Static analysis with mypy (ignoring pydantic v2)
83+
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '0.28.1'
84+
name: Static analysis with mypy (ignoring pydantic v2 and openai v1)
8385
run: |
84-
make type-pydantic-v1
86+
make type-pydantic-v1-openai-v0
87+
88+
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '1.2.4'
89+
name: Static analysis with pyright (ignoring pydantic v1 and openai v0)
90+
run: |
91+
make type-pydantic-v2-openai-v1
92+
93+
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '1.2.4'
94+
name: Static analysis with mypy (ignoring pydantic v2 and openai v0)
95+
run: |
96+
make type-pydantic-v1-openai-v1
8597
8698
Pytests:
8799
runs-on: ubuntu-latest
@@ -92,6 +104,7 @@ jobs:
92104
# dependencies: ['dev', 'full']
93105
dependencies: ['full']
94106
pydantic-version: ['1.10.9', '2.4.2']
107+
openai-version: ['0.28.1', '1.2.4']
95108
steps:
96109
- uses: actions/checkout@v2
97110
- name: Set up Python ${{ matrix.python-version }}
@@ -103,15 +116,16 @@ jobs:
103116
uses: actions/cache@v3
104117
with:
105118
path: ~/.cache/pypoetry
106-
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}
119+
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}-${{ matrix.openai-version }}
107120

108121
- name: Install Poetry
109122
uses: snok/install-poetry@v1
110123

111124
- name: Install Dependencies
112125
run: |
113126
make ${{ matrix.dependencies }}
114-
python -m pip install pydantic==${{ matrix.pydantic-version }}
127+
poetry run pip install pydantic==${{ matrix.pydantic-version }}
128+
poetry run pip install openai==${{ matrix.openai-version }}
115129
116130
- name: Run Pytests
117131
run: |

Makefile

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,23 @@ autoformat:
88
type:
99
poetry run pyright guardrails/
1010

11-
type-pydantic-v1:
12-
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py"]}' > pyrightconfig.json
11+
type-pydantic-v1-openai-v0:
12+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
1313
poetry run pyright guardrails/
1414
rm pyrightconfig.json
1515

16-
type-pydantic-v2:
17-
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py"]}' > pyrightconfig.json
16+
type-pydantic-v1-openai-v1:
17+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
18+
poetry run pyright guardrails/
19+
rm pyrightconfig.json
20+
21+
type-pydantic-v2-openai-v0:
22+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
23+
poetry run pyright guardrails/
24+
rm pyrightconfig.json
25+
26+
type-pydantic-v2-openai-v1:
27+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
1828
poetry run pyright guardrails/
1929
rm pyrightconfig.json
2030

guardrails/applications/text2sql.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
from string import Template
55
from typing import Callable, Dict, Optional, Type
66

7-
import openai
8-
97
from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
108
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
119
from guardrails.guard import Guard
10+
from guardrails.utils.openai_utils import get_static_openai_create_func
1211
from guardrails.utils.sql_utils import create_sql_driver
1312
from guardrails.vectordb import Faiss, VectorDBBase
1413

@@ -71,7 +70,7 @@ def __init__(
7170
rail_params: Optional[Dict] = None,
7271
example_formatter: Callable = example_formatter,
7372
reask_prompt: str = REASK_PROMPT,
74-
llm_api: Callable = openai.Completion.create,
73+
llm_api: Optional[Callable] = None,
7574
llm_api_kwargs: Optional[Dict] = None,
7675
num_relevant_examples: int = 2,
7776
):
@@ -88,6 +87,8 @@ def __init__(
8887
example_formatter: Fn to format examples. Defaults to example_formatter.
8988
reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
9089
"""
90+
if llm_api is None:
91+
llm_api = get_static_openai_create_func()
9192

9293
self.example_formatter = example_formatter
9394
self.llm_api = llm_api
@@ -185,9 +186,10 @@ def __call__(self, text: str) -> Optional[str]:
185186
"Async API is not supported in Text2SQL application. "
186187
"Please use a synchronous API."
187188
)
188-
189+
if self.llm_api is None:
190+
return None
189191
try:
190-
output = self.guard(
192+
return self.guard(
191193
self.llm_api,
192194
prompt_params={
193195
"nl_instruction": text,
@@ -201,6 +203,4 @@ def __call__(self, text: str) -> Optional[str]:
201203
"generated_sql"
202204
]
203205
except TypeError:
204-
output = None
205-
206-
return output
206+
return None

guardrails/embedding.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import os
21
from abc import ABC, abstractmethod
32
from functools import cached_property
43
from itertools import islice
54
from typing import Callable, List, Optional
65

7-
import openai
6+
from guardrails.utils.openai_utils import OpenAIClient
87

98

109
class EmbeddingBase(ABC):
@@ -114,9 +113,9 @@ def output_dim(self) -> int:
114113
class OpenAIEmbedding(EmbeddingBase):
115114
def __init__(
116115
self,
117-
model: Optional[str] = "text-embedding-ada-002",
118-
encoding_name: Optional[str] = "cl100k_base",
119-
max_tokens: Optional[int] = 8191,
116+
model: str = "text-embedding-ada-002",
117+
encoding_name: str = "cl100k_base",
118+
max_tokens: int = 8191,
120119
api_key: Optional[str] = None,
121120
api_base: Optional[str] = None,
122121
):
@@ -137,15 +136,14 @@ def embed_query(self, query: str) -> List[float]:
137136
return resp[0]
138137

139138
def _get_embedding(self, texts: List[str]) -> List[List[float]]:
140-
api_key = (
141-
self.api_key
142-
if self.api_key is not None
143-
else os.environ.get("OPENAI_API_KEY")
139+
client = OpenAIClient(
140+
api_key=self.api_key,
141+
api_base=self.api_base,
144142
)
145-
resp = openai.Embedding.create(
146-
api_key=api_key, model=self._model, input=texts, api_base=self.api_base
143+
return client.create_embedding(
144+
model=self._model,
145+
input=texts,
147146
)
148-
return [r["embedding"] for r in resp["data"]] # type: ignore
149147

150148
@property
151149
def output_dim(self) -> int:

0 commit comments

Comments
 (0)