Skip to content

Commit a1ab98c

Browse files
committed
cleanup, enable async support
1 parent 060143e commit a1ab98c

File tree

3 files changed

+94
-47
lines changed

3 files changed

+94
-47
lines changed

guardrails/async_guard.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from builtins import id as object_id
22
import contextvars
33
import inspect
4+
from guardrails.formatters.base_formatter import BaseFormatter
45
from opentelemetry import context as otel_context
56
from typing import (
67
Any,
@@ -99,6 +100,8 @@ def for_pydantic(
99100
tracer: Optional[Tracer] = None,
100101
name: Optional[str] = None,
101102
description: Optional[str] = None,
103+
output_formatter: Optional[Union[str, BaseFormatter]] = None,
104+
**kwargs,
102105
):
103106
guard = super().for_pydantic(
104107
output_class,
@@ -108,6 +111,8 @@ def for_pydantic(
108111
tracer=tracer,
109112
name=name,
110113
description=description,
114+
output_formatter=output_formatter,
115+
**kwargs,
111116
)
112117
if guard._output_type == OutputTypes.LIST:
113118
return cast(AsyncGuard[List], guard)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from guardrails.integrations.nemoguardrails.nemoguardrails_guard import (
2+
NemoguardrailsGuard,
3+
AsyncNemoguardrailsGuard,
4+
)
5+
6+
__all__ = ["NemoguardrailsGuard", "AsyncNemoguardrailsGuard"]

guardrails/integrations/nemoguardrails/nemoguardrails_guard.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
1-
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Union, cast
1+
import inspect
2+
from functools import partial
3+
from typing import (
4+
Any,
5+
AsyncIterator,
6+
Awaitable,
7+
Callable,
8+
Dict,
9+
Generic,
10+
Iterable,
11+
List,
12+
Optional,
13+
Union,
14+
cast,
15+
)
216
from typing_extensions import deprecated
317

418
from guardrails.classes.output_type import OT, OutputTypes
519
from guardrails.classes.validation_outcome import ValidationOutcome
620
from guardrails.classes.validation.validator_reference import ValidatorReference
721

8-
from guardrails import Guard
22+
from guardrails import Guard, AsyncGuard
923

1024
from guardrails.formatters.base_formatter import BaseFormatter
1125
from guardrails.types.pydantic import ModelOrListOfModels
@@ -20,6 +34,17 @@
2034
"`pip install nemoguardrails`."
2135
)
2236

37+
try:
38+
import nest_asyncio
39+
40+
nest_asyncio.apply()
41+
import asyncio
42+
except ImportError:
43+
raise ImportError(
44+
"Could not import nest_asyncio, please install it with "
45+
"`pip install nest_asyncio`."
46+
)
47+
2348

2449
class NemoguardrailsGuard(Guard, Generic[OT]):
2550
def __init__(
@@ -30,6 +55,28 @@ def __init__(
3055
):
3156
super().__init__(*args, **kwargs)
3257
self._nemorails = nemorails
58+
self._generate = self._nemorails.generate
59+
60+
def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
61+
# .generate doesn't like temp
62+
kwargs.pop("temperature", None)
63+
64+
messages = kwargs.pop("messages", None)
65+
66+
if messages == [] or messages is None:
67+
raise ValueError("messages must be passed during a call.")
68+
69+
if not generate_kwargs:
70+
generate_kwargs = {}
71+
72+
response = self._generate(messages=messages, **generate_kwargs)
73+
74+
if inspect.iscoroutine(response):
75+
response = asyncio.run(response)
76+
77+
return response[ # type: ignore
78+
"content"
79+
]
3380

3481
def __call__(
3582
self,
@@ -59,12 +106,9 @@ def __call__(
59106
dictionaries, where each dictionary has a 'role' key and a 'content' key."""
60107
)
61108

62-
def _custom_nemo_callable(*args, **kwargs):
63-
return self._custom_nemo_callable(
64-
*args, generate_kwargs=generate_kwargs, **kwargs
65-
)
109+
llm_api = partial(self._custom_nemo_callable, generate_kwargs=generate_kwargs)
66110

67-
return super().__call__(llm_api=_custom_nemo_callable, *args, **kwargs)
111+
return super().__call__(llm_api=llm_api, *args, **kwargs)
68112

69113
@classmethod
70114
def _init_guard_for_cls_method(
@@ -89,8 +133,8 @@ def _init_guard_for_cls_method(
89133
def for_pydantic(
90134
cls,
91135
output_class: ModelOrListOfModels,
92-
nemorails: LLMRails,
93136
*,
137+
nemorails: LLMRails,
94138
num_reasks: Optional[int] = None,
95139
reask_messages: Optional[List[Dict]] = None,
96140
messages: Optional[List[Dict]] = None,
@@ -116,45 +160,6 @@ def for_pydantic(
116160
else:
117161
return cast(NemoguardrailsGuard[Dict], guard)
118162

119-
# create the callable
120-
def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
121-
# .generate doesn't like temp
122-
kwargs.pop("temperature", None)
123-
124-
# msg_history, messages, prompt, and instruction all may or may not be present.
125-
# if none of them are present, raise an error
126-
# if messages is present, use that
127-
# if msg_history is present, use
128-
129-
msg_history = kwargs.pop("msg_history", None)
130-
messages = kwargs.pop("messages", None)
131-
prompt = kwargs.pop("prompt", None)
132-
instructions = kwargs.pop("instructions", None)
133-
134-
if msg_history is not None and messages is None:
135-
messages = msg_history
136-
137-
if messages is None and msg_history is None:
138-
messages = []
139-
if instructions is not None:
140-
messages.append({"role": "system", "content": instructions})
141-
if prompt is not None:
142-
messages.append({"role": "system", "content": prompt})
143-
144-
if messages == [] or messages is None:
145-
raise ValueError(
146-
"messages, prompt, or instructions should be passed during a call."
147-
)
148-
149-
# kwargs["messages"] = messages
150-
151-
# return (self._nemorails.generate(**kwargs))["content"] # type: ignore
152-
if not generate_kwargs:
153-
generate_kwargs = {}
154-
return (self._nemorails.generate(messages=messages, **generate_kwargs))[ # type: ignore
155-
"content"
156-
]
157-
158163
@deprecated(
159164
"Use `for_rail_string` instead. This method will be removed in 0.6.x.",
160165
category=None,
@@ -190,3 +195,34 @@ def for_rail(cls, *args, **kwargs):
190195
`for_rail` is not implemented for NemoguardrailsGuard.
191196
We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
192197
or the `from_pydantic` method.""")
198+
199+
200+
class AsyncNemoguardrailsGuard(NemoguardrailsGuard, AsyncGuard, Generic[OT]):
201+
def __init__(
202+
self,
203+
nemorails: LLMRails,
204+
*args,
205+
**kwargs,
206+
):
207+
super().__init__(nemorails, *args, **kwargs)
208+
self._generate = self._nemorails.generate_async
209+
210+
async def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
211+
return super()._custom_nemo_callable(
212+
*args, generate_kwargs=generate_kwargs, **kwargs
213+
)
214+
215+
async def __call__( # type: ignore
216+
self,
217+
llm_api: Optional[Callable] = None,
218+
generate_kwargs: Optional[Dict] = None,
219+
*args,
220+
**kwargs,
221+
) -> Union[
222+
ValidationOutcome[OT],
223+
Awaitable[ValidationOutcome[OT]],
224+
AsyncIterator[ValidationOutcome[OT]],
225+
]:
226+
return await super().__call__(
227+
llm_api=llm_api, generate_kwargs=generate_kwargs, *args, **kwargs
228+
) # type: ignore

0 commit comments

Comments
 (0)