Skip to content

Commit db5544f

Browse files
committed
0.6.0 update
1 parent a4190f3 commit db5544f

File tree

2 files changed

+105
-109
lines changed

2 files changed

+105
-109
lines changed

guardrails/guard.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def _for_rail_schema(
380380
name: Optional[str] = None,
381381
description: Optional[str] = None,
382382
):
383-
guard = cls(
383+
guard = cls._init_guard_for_cls_method(
384384
name=name,
385385
description=description,
386386
output_schema=schema.json_schema,
@@ -526,6 +526,25 @@ def for_rail_string(
526526
def from_pydantic(cls, output_class: ModelOrListOfModels, *args, **kwargs):
527527
return cls.for_pydantic(output_class, **kwargs)
528528

529+
@classmethod
530+
def _init_guard_for_cls_method(
531+
cls,
532+
*,
533+
id: Optional[str] = None,
534+
name: Optional[str] = None,
535+
description: Optional[str] = None,
536+
validators: Optional[List[ValidatorReference]] = None,
537+
output_schema: Optional[Dict[str, Any]] = None,
538+
**kwargs,
539+
):
540+
return cls(
541+
id=id,
542+
name=name,
543+
description=description,
544+
output_schema=output_schema,
545+
validators=validators,
546+
)
547+
529548
@classmethod
530549
def for_pydantic(
531550
cls,
@@ -538,6 +557,7 @@ def for_pydantic(
538557
name: Optional[str] = None,
539558
description: Optional[str] = None,
540559
output_formatter: Optional[Union[str, BaseFormatter]] = None,
560+
**kwargs,
541561
):
542562
"""Create a Guard instance using a Pydantic model to specify the output
543563
schema.
@@ -574,11 +594,12 @@ def for_pydantic(
574594
reask_messages=reask_messages,
575595
messages=messages,
576596
)
577-
guard = cls(
597+
guard = cls._init_guard_for_cls_method(
578598
name=name,
579599
description=description,
580600
output_schema=schema.json_schema,
581601
validators=schema.validators,
602+
**kwargs,
582603
)
583604
if schema.output_type == OutputTypes.LIST:
584605
guard = cast(Guard[List], guard)
@@ -1306,7 +1327,7 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional["Guard"]:
13061327
i_guard.output_schema.to_dict() if i_guard.output_schema else None
13071328
)
13081329

1309-
guard = cls(
1330+
guard = cls._init_guard_for_cls_method(
13101331
id=i_guard.id,
13111332
name=i_guard.name,
13121333
description=i_guard.description,
Lines changed: 81 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1-
from typing import Callable, Dict, Iterable, List, Optional, Union, cast
2-
import warnings
1+
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Union, cast
32
from typing_extensions import deprecated
43

5-
from guardrails.classes.execution.guard_execution_options import GuardExecutionOptions
64
from guardrails.classes.output_type import OT, OutputTypes
75
from guardrails.classes.validation_outcome import ValidationOutcome
6+
from guardrails.classes.validation.validator_reference import ValidatorReference
87

98
from guardrails import Guard
10-
from nemoguardrails import LLMRails
119

12-
from guardrails.formatters import get_formatter
1310
from guardrails.formatters.base_formatter import BaseFormatter
14-
from guardrails.schema.pydantic_schema import pydantic_model_to_schema
1511
from guardrails.types.pydantic import ModelOrListOfModels
1612

17-
from guardrails.stores.context import (
18-
Tracer
19-
)
13+
from guardrails.stores.context import Tracer
14+
15+
try:
16+
from nemoguardrails import LLMRails
17+
except ImportError:
18+
raise ImportError(
19+
"Could not import nemoguardrails, please install it with "
20+
"`pip install nemoguardrails`."
21+
)
2022

2123

22-
class NemoguardrailsGuard(Guard):
24+
class NemoguardrailsGuard(Guard, Generic[OT]):
2325
def __init__(
2426
self,
2527
nemorails: LLMRails,
@@ -30,7 +32,11 @@ def __init__(
3032
self._nemorails = nemorails
3133

3234
def __call__(
33-
self, llm_api: Optional[Callable] = None, generate_kwargs: Optional[Dict] = None, *args, **kwargs
35+
self,
36+
llm_api: Optional[Callable] = None,
37+
generate_kwargs: Optional[Dict] = None,
38+
*args,
39+
**kwargs,
3440
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
3541
# peel llm_api off of kwargs
3642
llm_api = kwargs.pop("llm_api", None)
@@ -54,114 +60,61 @@ def __call__(
5460
)
5561

5662
def _custom_nemo_callable(*args, **kwargs):
57-
return self._custom_nemo_callable(*args, generate_kwargs=generate_kwargs, **kwargs)
63+
return self._custom_nemo_callable(
64+
*args, generate_kwargs=generate_kwargs, **kwargs
65+
)
5866

5967
return super().__call__(llm_api=_custom_nemo_callable, *args, **kwargs)
6068

6169
@classmethod
62-
def from_pydantic(
70+
def _init_guard_for_cls_method(
6371
cls,
72+
*,
73+
name: Optional[str] = None,
74+
description: Optional[str] = None,
75+
validators: Optional[List[ValidatorReference]] = None,
76+
output_schema: Optional[Dict[str, Any]] = None,
6477
nemorails: LLMRails,
78+
**kwargs,
79+
):
80+
return cls(
81+
nemorails,
82+
name=name,
83+
description=description,
84+
output_schema=output_schema,
85+
validators=validators,
86+
)
87+
88+
@classmethod
89+
def for_pydantic(
90+
cls,
6591
output_class: ModelOrListOfModels,
92+
nemorails: LLMRails,
6693
*,
67-
prompt: Optional[str] = None,
68-
instructions: Optional[str] = None,
6994
num_reasks: Optional[int] = None,
70-
reask_prompt: Optional[str] = None,
71-
reask_instructions: Optional[str] = None,
7295
reask_messages: Optional[List[Dict]] = None,
7396
messages: Optional[List[Dict]] = None,
7497
tracer: Optional[Tracer] = None,
7598
name: Optional[str] = None,
7699
description: Optional[str] = None,
77100
output_formatter: Optional[Union[str, BaseFormatter]] = None,
101+
**kwargs,
78102
):
79-
"""Create a Guard instance using a Pydantic model to specify the output
80-
schema.
81-
82-
Args:
83-
output_class: (Union[Type[BaseModel], List[Type[BaseModel]]]): The pydantic model that describes
84-
the desired structure of the output.
85-
prompt (str, optional): The prompt used to generate the string. Defaults to None.
86-
instructions (str, optional): Instructions for chat models. Defaults to None.
87-
reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
88-
reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
89-
reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None.
90-
num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated
91-
tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None.
92-
name (str, optional): A unique name for this Guard. Defaults to `gr-` + the object id.
93-
description (str, optional): A description for this Guard. Defaults to None.
94-
output_formatter (str | Formatter, optional): 'none' (default), 'jsonformer', or a Guardrails Formatter.
95-
""" # noqa
96-
97-
if num_reasks:
98-
warnings.warn(
99-
"Setting num_reasks during initialization is deprecated"
100-
" and will be removed in 0.6.x!"
101-
"We recommend setting num_reasks when calling guard()"
102-
" or guard.parse() instead."
103-
"If you insist on setting it at the Guard level,"
104-
" use 'Guard.configure()'.",
105-
DeprecationWarning,
106-
)
107-
108-
if reask_instructions:
109-
warnings.warn(
110-
"reask_instructions is deprecated and will be removed in 0.6.x!"
111-
"Please be prepared to set reask_messages instead.",
112-
DeprecationWarning,
113-
)
114-
if reask_prompt:
115-
warnings.warn(
116-
"reask_prompt is deprecated and will be removed in 0.6.x!"
117-
"Please be prepared to set reask_messages instead.",
118-
DeprecationWarning,
119-
)
120-
121-
# We have to set the tracer in the ContextStore before the Rail,
122-
# and therefore the Validators, are initialized
123-
cls._set_tracer(cls, tracer) # type: ignore
124-
125-
schema = pydantic_model_to_schema(output_class)
126-
exec_opts = GuardExecutionOptions(
127-
prompt=prompt,
128-
instructions=instructions,
129-
reask_prompt=reask_prompt,
130-
reask_instructions=reask_instructions,
131-
reask_messages=reask_messages,
103+
guard = super().for_pydantic(
104+
output_class,
105+
num_reasks=num_reasks,
132106
messages=messages,
133-
)
134-
135-
# TODO: This is the only line that's changed vs the parent Guard class
136-
# Find a way to refactor this
137-
guard = cls(
138-
nemorails=nemorails,
107+
reask_messages=reask_messages,
108+
tracer=tracer,
139109
name=name,
140110
description=description,
141-
output_schema=schema.json_schema,
142-
validators=schema.validators,
111+
output_formatter=output_formatter,
112+
nemorails=nemorails,
143113
)
144-
if schema.output_type == OutputTypes.LIST:
145-
guard = cast(Guard[List], guard)
114+
if guard._output_type == OutputTypes.LIST:
115+
return cast(NemoguardrailsGuard[List], guard)
146116
else:
147-
guard = cast(Guard[Dict], guard)
148-
guard.configure(num_reasks=num_reasks, tracer=tracer)
149-
guard._validator_map = schema.validator_map
150-
guard._exec_opts = exec_opts
151-
guard._output_type = schema.output_type
152-
guard._base_model = output_class
153-
if isinstance(output_formatter, str):
154-
if isinstance(output_class, list):
155-
raise Exception("""Root-level arrays are not supported with the
156-
jsonformer argument, but can be used with other json generation methods.
157-
Omit the output_formatter argument to use the other methods.""")
158-
output_formatter = get_formatter(
159-
output_formatter,
160-
schema=output_class.model_json_schema(), # type: ignore
161-
)
162-
guard._output_formatter = output_formatter
163-
guard._fill_validators()
164-
return guard
117+
return cast(NemoguardrailsGuard[Dict], guard)
165118

166119
# create the callable
167120
def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
@@ -171,13 +124,13 @@ def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
171124
# msg_history, messages, prompt, and instruction all may or may not be present.
172125
# if none of them are present, raise an error
173126
# if messages is present, use that
174-
# if msg_history is present, use
127+
# if msg_history is present, use
175128

176129
msg_history = kwargs.pop("msg_history", None)
177130
messages = kwargs.pop("messages", None)
178131
prompt = kwargs.pop("prompt", None)
179132
instructions = kwargs.pop("instructions", None)
180-
133+
181134
if msg_history is not None and messages is None:
182135
messages = msg_history
183136

@@ -188,30 +141,52 @@ def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
188141
if prompt is not None:
189142
messages.append({"role": "system", "content": prompt})
190143

191-
if messages is [] or messages is None:
192-
raise ValueError("messages, prompt, or instructions should be passed during a call.")
193-
144+
if messages == [] or messages is None:
145+
raise ValueError(
146+
"messages, prompt, or instructions should be passed during a call."
147+
)
148+
194149
# kwargs["messages"] = messages
195150

196151
# return (self._nemorails.generate(**kwargs))["content"] # type: ignore
197152
if not generate_kwargs:
198153
generate_kwargs = {}
199-
return (self._nemorails.generate(messages=messages, **generate_kwargs))["content"] # type: ignore
154+
return (self._nemorails.generate(messages=messages, **generate_kwargs))[ # type: ignore
155+
"content"
156+
]
200157

201158
@deprecated(
202-
"This method has been deprecated. Please use the main constructor `NemoGuardrailsGuard(nemorails=nemorails)` or the `from_pydantic` method.",
159+
"Use `for_rail_string` instead. This method will be removed in 0.6.x.",
160+
category=None,
203161
)
162+
@classmethod
204163
def from_rail_string(cls, *args, **kwargs):
205164
raise NotImplementedError("""\
206165
`from_rail_string` is not implemented for NemoguardrailsGuard.
207166
We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
167+
or the `from_pydantic` method.""")
168+
169+
@classmethod
170+
def for_rail_string(cls, *args, **kwargs):
171+
raise NotImplementedError("""\
172+
`for_rail_string` is not implemented for NemoguardrailsGuard.
173+
We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
208174
or the `from_pydantic` method.""")
209175

210176
@deprecated(
211-
"This method has been deprecated. Please use the main constructor `NemoGuardrailsGuard(nemorails=nemorails)` or the `from_pydantic` method.",
177+
"Use `for_rail` instead. This method will be removed in 0.6.x.",
178+
category=None,
212179
)
180+
@classmethod
213181
def from_rail(cls, *args, **kwargs):
214182
raise NotImplementedError("""\
215183
`from_rail` is not implemented for NemoguardrailsGuard.
216184
We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
185+
or the `from_pydantic` method.""")
186+
187+
@classmethod
188+
def for_rail(cls, *args, **kwargs):
189+
raise NotImplementedError("""\
190+
`for_rail` is not implemented for NemoguardrailsGuard.
191+
We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
217192
or the `from_pydantic` method.""")

0 commit comments

Comments
 (0)