Skip to content

Commit 9cae16f

Browse files
committed
Merge branch 'main' into feature/claude
2 parents 7781485 + f1f511c commit 9cae16f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1050
-473
lines changed

.github/workflows/ci.yml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jobs:
5151
matrix:
5252
python-version: ['3.8', '3.9', '3.10', '3.11']
5353
pydantic-version: ['1.10.9', '2.4.2']
54+
openai-version: ['0.28.1', '1.2.4']
5455
steps:
5556
- uses: actions/checkout@v2
5657
- name: Set up Python ${{ matrix.python-version }}
@@ -72,16 +73,27 @@ jobs:
7273
run: |
7374
make full
7475
poetry run pip install pydantic==${{ matrix.pydantic-version }}
76+
poetry run pip install openai==${{ matrix.openai-version }}
7577
76-
- if: matrix.pydantic-version == '2.4.2'
77-
name: Static analysis with pyright (ignoring pydantic v1)
78+
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '0.28.1'
79+
name: Static analysis with pyright (ignoring pydantic v1 and openai v1)
7880
run: |
79-
make type-pydantic-v2
81+
make type-pydantic-v2-openai-v0
8082
81-
- if: matrix.pydantic-version == '1.10.9'
82-
name: Static analysis with mypy (ignoring pydantic v2)
83+
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '0.28.1'
84+
name: Static analysis with mypy (ignoring pydantic v2 and openai v1)
8385
run: |
84-
make type-pydantic-v1
86+
make type-pydantic-v1-openai-v0
87+
88+
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '1.2.4'
89+
name: Static analysis with pyright (ignoring pydantic v1 and openai v0)
90+
run: |
91+
make type-pydantic-v2-openai-v1
92+
93+
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '1.2.4'
94+
name: Static analysis with mypy (ignoring pydantic v2 and openai v0)
95+
run: |
96+
make type-pydantic-v1-openai-v1
8597
8698
Pytests:
8799
runs-on: ubuntu-latest
@@ -92,6 +104,7 @@ jobs:
92104
# dependencies: ['dev', 'full']
93105
dependencies: ['full']
94106
pydantic-version: ['1.10.9', '2.4.2']
107+
openai-version: ['0.28.1', '1.2.4']
95108
steps:
96109
- uses: actions/checkout@v2
97110
- name: Set up Python ${{ matrix.python-version }}
@@ -103,15 +116,16 @@ jobs:
103116
uses: actions/cache@v3
104117
with:
105118
path: ~/.cache/pypoetry
106-
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}
119+
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}-${{ matrix.openai-version }}
107120

108121
- name: Install Poetry
109122
uses: snok/install-poetry@v1
110123

111124
- name: Install Dependencies
112125
run: |
113126
make ${{ matrix.dependencies }}
114-
python -m pip install pydantic==${{ matrix.pydantic-version }}
127+
poetry run pip install pydantic==${{ matrix.pydantic-version }}
128+
poetry run pip install openai==${{ matrix.openai-version }}
115129
116130
- name: Run Pytests
117131
run: |

Makefile

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,23 @@ autoformat:
88
type:
99
poetry run pyright guardrails/
1010

11-
type-pydantic-v1:
12-
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py"]}' > pyrightconfig.json
11+
type-pydantic-v1-openai-v0:
12+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
1313
poetry run pyright guardrails/
1414
rm pyrightconfig.json
1515

16-
type-pydantic-v2:
17-
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py"]}' > pyrightconfig.json
16+
type-pydantic-v1-openai-v1:
17+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
18+
poetry run pyright guardrails/
19+
rm pyrightconfig.json
20+
21+
type-pydantic-v2-openai-v0:
22+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
23+
poetry run pyright guardrails/
24+
rm pyrightconfig.json
25+
26+
type-pydantic-v2-openai-v1:
27+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
1828
poetry run pyright guardrails/
1929
rm pyrightconfig.json
2030

guardrails/applications/text2sql.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
from string import Template
55
from typing import Callable, Dict, Optional, Type
66

7-
import openai
8-
97
from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
108
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
119
from guardrails.guard import Guard
10+
from guardrails.utils.openai_utils import get_static_openai_create_func
1211
from guardrails.utils.sql_utils import create_sql_driver
1312
from guardrails.vectordb import Faiss, VectorDBBase
1413

@@ -71,7 +70,7 @@ def __init__(
7170
rail_params: Optional[Dict] = None,
7271
example_formatter: Callable = example_formatter,
7372
reask_prompt: str = REASK_PROMPT,
74-
llm_api: Callable = openai.Completion.create,
73+
llm_api: Optional[Callable] = None,
7574
llm_api_kwargs: Optional[Dict] = None,
7675
num_relevant_examples: int = 2,
7776
):
@@ -88,6 +87,8 @@ def __init__(
8887
example_formatter: Fn to format examples. Defaults to example_formatter.
8988
reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
9089
"""
90+
if llm_api is None:
91+
llm_api = get_static_openai_create_func()
9192

9293
self.example_formatter = example_formatter
9394
self.llm_api = llm_api
@@ -185,9 +186,10 @@ def __call__(self, text: str) -> Optional[str]:
185186
"Async API is not supported in Text2SQL application. "
186187
"Please use a synchronous API."
187188
)
188-
189+
if self.llm_api is None:
190+
return None
189191
try:
190-
output = self.guard(
192+
return self.guard(
191193
self.llm_api,
192194
prompt_params={
193195
"nl_instruction": text,
@@ -201,6 +203,4 @@ def __call__(self, text: str) -> Optional[str]:
201203
"generated_sql"
202204
]
203205
except TypeError:
204-
output = None
205-
206-
return output
206+
return None

guardrails/datatypes.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from lxml import etree as ET
1212
from typing_extensions import Self
1313

14-
from guardrails.formatattr import FormatAttr
1514
from guardrails.utils.casting_utils import to_float, to_int, to_string
1615
from guardrails.utils.xml_utils import cast_xml_to_string
1716
from guardrails.validator_base import Validator, ValidatorSpec
17+
from guardrails.validatorsattr import ValidatorsAttr
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -62,20 +62,20 @@ class DataType:
6262
def __init__(
6363
self,
6464
children: Dict[str, Any],
65-
format_attr: FormatAttr,
65+
validators_attr: ValidatorsAttr,
6666
optional: bool,
6767
name: Optional[str],
6868
description: Optional[str],
6969
) -> None:
7070
self._children = children
71-
self.format_attr = format_attr
71+
self.validators_attr = validators_attr
7272
self.name = name
7373
self.description = description
7474
self.optional = optional
7575

7676
@property
7777
def validators(self) -> TypedList:
78-
return self.format_attr.validators
78+
return self.validators_attr.validators
7979

8080
def __repr__(self) -> str:
8181
return f"{self.__class__.__name__}({self._children})"
@@ -119,9 +119,9 @@ def set_children_from_xml(self, element: ET._Element):
119119
@classmethod
120120
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
121121
# TODO: don't want to pass strict through to DataType,
122-
# but need to pass it to FormatAttr.from_xml
122+
# but need to pass it to ValidatorsAttr.from_element
123123
# how to handle this?
124-
format_attr = FormatAttr.from_xml(element, cls.tag, strict)
124+
validators_attr = ValidatorsAttr.from_xml(element, cls.tag, strict)
125125

126126
is_optional = element.attrib.get("required", "true") == "false"
127127

@@ -133,7 +133,7 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
133133
if description is not None:
134134
description = cast_xml_to_string(description)
135135

136-
data_type = cls({}, format_attr, is_optional, name, description, **kwargs)
136+
data_type = cls({}, validators_attr, is_optional, name, description, **kwargs)
137137
data_type.set_children_from_xml(element)
138138
return data_type
139139

@@ -203,7 +203,7 @@ def from_string_rail(
203203
) -> Self:
204204
return cls(
205205
children={},
206-
format_attr=FormatAttr.from_validators(validators, cls.tag, strict),
206+
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
207207
optional=False,
208208
name=None,
209209
description=description,
@@ -267,12 +267,12 @@ class Date(ScalarType):
267267
def __init__(
268268
self,
269269
children: Dict[str, Any],
270-
format_attr: "FormatAttr",
270+
validators_attr: "ValidatorsAttr",
271271
optional: bool,
272272
name: Optional[str],
273273
description: Optional[str],
274274
) -> None:
275-
super().__init__(children, format_attr, optional, name, description)
275+
super().__init__(children, validators_attr, optional, name, description)
276276
self.date_format = None
277277

278278
def from_str(self, s: str) -> Optional[datetime.date]:
@@ -306,13 +306,13 @@ class Time(ScalarType):
306306
def __init__(
307307
self,
308308
children: Dict[str, Any],
309-
format_attr: "FormatAttr",
309+
validators_attr: "ValidatorsAttr",
310310
optional: bool,
311311
name: Optional[str],
312312
description: Optional[str],
313313
) -> None:
314314
self.time_format = "%H:%M:%S"
315-
super().__init__(children, format_attr, optional, name, description)
315+
super().__init__(children, validators_attr, optional, name, description)
316316

317317
def from_str(self, s: str) -> Optional[datetime.time]:
318318
"""Create a Time from a string."""
@@ -486,13 +486,13 @@ class Choice(NonScalarType):
486486
def __init__(
487487
self,
488488
children: Dict[str, Any],
489-
format_attr: "FormatAttr",
489+
validators_attr: "ValidatorsAttr",
490490
optional: bool,
491491
name: Optional[str],
492492
description: Optional[str],
493493
discriminator_key: str,
494494
) -> None:
495-
super().__init__(children, format_attr, optional, name, description)
495+
super().__init__(children, validators_attr, optional, name, description)
496496
self.discriminator_key = discriminator_key
497497

498498
@classmethod
@@ -548,12 +548,12 @@ class Case(NonScalarType):
548548
def __init__(
549549
self,
550550
children: Dict[str, Any],
551-
format_attr: "FormatAttr",
551+
validators_attr: "ValidatorsAttr",
552552
optional: bool,
553553
name: Optional[str],
554554
description: Optional[str],
555555
) -> None:
556-
super().__init__(children, format_attr, optional, name, description)
556+
super().__init__(children, validators_attr, optional, name, description)
557557

558558
def collect_validation(
559559
self,

guardrails/embedding.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import os
21
from abc import ABC, abstractmethod
32
from functools import cached_property
43
from itertools import islice
54
from typing import Callable, List, Optional
65

7-
import openai
6+
from guardrails.utils.openai_utils import OpenAIClient
87

98

109
class EmbeddingBase(ABC):
@@ -114,9 +113,9 @@ def output_dim(self) -> int:
114113
class OpenAIEmbedding(EmbeddingBase):
115114
def __init__(
116115
self,
117-
model: Optional[str] = "text-embedding-ada-002",
118-
encoding_name: Optional[str] = "cl100k_base",
119-
max_tokens: Optional[int] = 8191,
116+
model: str = "text-embedding-ada-002",
117+
encoding_name: str = "cl100k_base",
118+
max_tokens: int = 8191,
120119
api_key: Optional[str] = None,
121120
api_base: Optional[str] = None,
122121
):
@@ -137,15 +136,14 @@ def embed_query(self, query: str) -> List[float]:
137136
return resp[0]
138137

139138
def _get_embedding(self, texts: List[str]) -> List[List[float]]:
140-
api_key = (
141-
self.api_key
142-
if self.api_key is not None
143-
else os.environ.get("OPENAI_API_KEY")
139+
client = OpenAIClient(
140+
api_key=self.api_key,
141+
api_base=self.api_base,
144142
)
145-
resp = openai.Embedding.create(
146-
api_key=api_key, model=self._model, input=texts, api_base=self.api_base
143+
return client.create_embedding(
144+
model=self._model,
145+
input=texts,
147146
)
148-
return [r["embedding"] for r in resp["data"]] # type: ignore
149147

150148
@property
151149
def output_dim(self) -> int:

0 commit comments

Comments
 (0)