1- from typing import Callable , Dict , Iterable , List , Optional , Union , cast
2- import warnings
1+ from typing import Any , Callable , Dict , Generic , Iterable , List , Optional , Union , cast
32from typing_extensions import deprecated
43
5- from guardrails .classes .execution .guard_execution_options import GuardExecutionOptions
64from guardrails .classes .output_type import OT , OutputTypes
75from guardrails .classes .validation_outcome import ValidationOutcome
6+ from guardrails .classes .validation .validator_reference import ValidatorReference
87
98from guardrails import Guard
10- from nemoguardrails import LLMRails
119
12- from guardrails .formatters import get_formatter
1310from guardrails .formatters .base_formatter import BaseFormatter
14- from guardrails .schema .pydantic_schema import pydantic_model_to_schema
1511from guardrails .types .pydantic import ModelOrListOfModels
1612
17- from guardrails .stores .context import (
18- Tracer
19- )
13+ from guardrails .stores .context import Tracer
14+
15+ try :
16+ from nemoguardrails import LLMRails
17+ except ImportError :
18+ raise ImportError (
19+ "Could not import nemoguardrails, please install it with "
20+ "`pip install nemoguardrails`."
21+ )
2022
2123
22- class NemoguardrailsGuard (Guard ):
24+ class NemoguardrailsGuard (Guard , Generic [ OT ] ):
2325 def __init__ (
2426 self ,
2527 nemorails : LLMRails ,
@@ -30,7 +32,11 @@ def __init__(
3032 self ._nemorails = nemorails
3133
3234 def __call__ (
33- self , llm_api : Optional [Callable ] = None , generate_kwargs : Optional [Dict ] = None , * args , ** kwargs
35+ self ,
36+ llm_api : Optional [Callable ] = None ,
37+ generate_kwargs : Optional [Dict ] = None ,
38+ * args ,
39+ ** kwargs ,
3440 ) -> Union [ValidationOutcome [OT ], Iterable [ValidationOutcome [OT ]]]:
3541 # peel llm_api off of kwargs
3642 llm_api = kwargs .pop ("llm_api" , None )
@@ -54,114 +60,61 @@ def __call__(
5460 )
5561
5662 def _custom_nemo_callable (* args , ** kwargs ):
57- return self ._custom_nemo_callable (* args , generate_kwargs = generate_kwargs , ** kwargs )
63+ return self ._custom_nemo_callable (
64+ * args , generate_kwargs = generate_kwargs , ** kwargs
65+ )
5866
5967 return super ().__call__ (llm_api = _custom_nemo_callable , * args , ** kwargs )
6068
6169 @classmethod
62- def from_pydantic (
70+ def _init_guard_for_cls_method (
6371 cls ,
72+ * ,
73+ name : Optional [str ] = None ,
74+ description : Optional [str ] = None ,
75+ validators : Optional [List [ValidatorReference ]] = None ,
76+ output_schema : Optional [Dict [str , Any ]] = None ,
6477 nemorails : LLMRails ,
78+ ** kwargs ,
79+ ):
80+ return cls (
81+ nemorails ,
82+ name = name ,
83+ description = description ,
84+ output_schema = output_schema ,
85+ validators = validators ,
86+ )
87+
88+ @classmethod
89+ def for_pydantic (
90+ cls ,
6591 output_class : ModelOrListOfModels ,
92+ nemorails : LLMRails ,
6693 * ,
67- prompt : Optional [str ] = None ,
68- instructions : Optional [str ] = None ,
6994 num_reasks : Optional [int ] = None ,
70- reask_prompt : Optional [str ] = None ,
71- reask_instructions : Optional [str ] = None ,
7295 reask_messages : Optional [List [Dict ]] = None ,
7396 messages : Optional [List [Dict ]] = None ,
7497 tracer : Optional [Tracer ] = None ,
7598 name : Optional [str ] = None ,
7699 description : Optional [str ] = None ,
77100 output_formatter : Optional [Union [str , BaseFormatter ]] = None ,
101+ ** kwargs ,
78102 ):
79- """Create a Guard instance using a Pydantic model to specify the output
80- schema.
81-
82- Args:
83- output_class: (Union[Type[BaseModel], List[Type[BaseModel]]]): The pydantic model that describes
84- the desired structure of the output.
85- prompt (str, optional): The prompt used to generate the string. Defaults to None.
86- instructions (str, optional): Instructions for chat models. Defaults to None.
87- reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None.
88- reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None.
89- reask_messages (List[Dict], optional): A list of messages to use during reasks. Defaults to None.
90- num_reasks (int, optional): The max times to re-ask the LLM if validation fails. Deprecated
91- tracer (Tracer, optional): An OpenTelemetry tracer to use for metrics and traces. Defaults to None.
92- name (str, optional): A unique name for this Guard. Defaults to `gr-` + the object id.
93- description (str, optional): A description for this Guard. Defaults to None.
94- output_formatter (str | Formatter, optional): 'none' (default), 'jsonformer', or a Guardrails Formatter.
95- """ # noqa
96-
97- if num_reasks :
98- warnings .warn (
99- "Setting num_reasks during initialization is deprecated"
100- " and will be removed in 0.6.x!"
101- "We recommend setting num_reasks when calling guard()"
102- " or guard.parse() instead."
103- "If you insist on setting it at the Guard level,"
104- " use 'Guard.configure()'." ,
105- DeprecationWarning ,
106- )
107-
108- if reask_instructions :
109- warnings .warn (
110- "reask_instructions is deprecated and will be removed in 0.6.x!"
111- "Please be prepared to set reask_messages instead." ,
112- DeprecationWarning ,
113- )
114- if reask_prompt :
115- warnings .warn (
116- "reask_prompt is deprecated and will be removed in 0.6.x!"
117- "Please be prepared to set reask_messages instead." ,
118- DeprecationWarning ,
119- )
120-
121- # We have to set the tracer in the ContextStore before the Rail,
122- # and therefore the Validators, are initialized
123- cls ._set_tracer (cls , tracer ) # type: ignore
124-
125- schema = pydantic_model_to_schema (output_class )
126- exec_opts = GuardExecutionOptions (
127- prompt = prompt ,
128- instructions = instructions ,
129- reask_prompt = reask_prompt ,
130- reask_instructions = reask_instructions ,
131- reask_messages = reask_messages ,
103+ guard = super ().for_pydantic (
104+ output_class ,
105+ num_reasks = num_reasks ,
132106 messages = messages ,
133- )
134-
135- # TODO: This is the only line that's changed vs the parent Guard class
136- # Find a way to refactor this
137- guard = cls (
138- nemorails = nemorails ,
107+ reask_messages = reask_messages ,
108+ tracer = tracer ,
139109 name = name ,
140110 description = description ,
141- output_schema = schema . json_schema ,
142- validators = schema . validators ,
111+ output_formatter = output_formatter ,
112+ nemorails = nemorails ,
143113 )
144- if schema . output_type == OutputTypes .LIST :
145- guard = cast (Guard [List ], guard )
114+ if guard . _output_type == OutputTypes .LIST :
115+ return cast (NemoguardrailsGuard [List ], guard )
146116 else :
147- guard = cast (Guard [Dict ], guard )
148- guard .configure (num_reasks = num_reasks , tracer = tracer )
149- guard ._validator_map = schema .validator_map
150- guard ._exec_opts = exec_opts
151- guard ._output_type = schema .output_type
152- guard ._base_model = output_class
153- if isinstance (output_formatter , str ):
154- if isinstance (output_class , list ):
155- raise Exception ("""Root-level arrays are not supported with the
156- jsonformer argument, but can be used with other json generation methods.
157- Omit the output_formatter argument to use the other methods.""" )
158- output_formatter = get_formatter (
159- output_formatter ,
160- schema = output_class .model_json_schema (), # type: ignore
161- )
162- guard ._output_formatter = output_formatter
163- guard ._fill_validators ()
164- return guard
117+ return cast (NemoguardrailsGuard [Dict ], guard )
165118
166119 # create the callable
167120 def _custom_nemo_callable (self , * args , generate_kwargs , ** kwargs ):
@@ -171,13 +124,13 @@ def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
171124 # msg_history, messages, prompt, and instruction all may or may not be present.
172125 # if none of them are present, raise an error
173126 # if messages is present, use that
174- # if msg_history is present, use
127+ # if msg_history is present, use
175128
176129 msg_history = kwargs .pop ("msg_history" , None )
177130 messages = kwargs .pop ("messages" , None )
178131 prompt = kwargs .pop ("prompt" , None )
179132 instructions = kwargs .pop ("instructions" , None )
180-
133+
181134 if msg_history is not None and messages is None :
182135 messages = msg_history
183136
@@ -188,30 +141,52 @@ def _custom_nemo_callable(self, *args, generate_kwargs, **kwargs):
188141 if prompt is not None :
189142 messages .append ({"role" : "system" , "content" : prompt })
190143
191- if messages is [] or messages is None :
192- raise ValueError ("messages, prompt, or instructions should be passed during a call." )
193-
144+ if messages == [] or messages is None :
145+ raise ValueError (
146+ "messages, prompt, or instructions should be passed during a call."
147+ )
148+
194149 # kwargs["messages"] = messages
195150
196151 # return (self._nemorails.generate(**kwargs))["content"] # type: ignore
197152 if not generate_kwargs :
198153 generate_kwargs = {}
199- return (self ._nemorails .generate (messages = messages , ** generate_kwargs ))["content" ] # type: ignore
154+ return (self ._nemorails .generate (messages = messages , ** generate_kwargs ))[ # type: ignore
155+ "content"
156+ ]
200157
201158 @deprecated (
202- "This method has been deprecated. Please use the main constructor `NemoGuardrailsGuard(nemorails=nemorails)` or the `from_pydantic` method." ,
159+ "Use `for_rail_string` instead. This method will be removed in 0.6.x." ,
160+ category = None ,
203161 )
162+ @classmethod
204163 def from_rail_string (cls , * args , ** kwargs ):
205164 raise NotImplementedError ("""\
206165 `from_rail_string` is not implemented for NemoguardrailsGuard.
207166We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
167+ or the `from_pydantic` method.""" )
168+
169+ @classmethod
170+ def for_rail_string (cls , * args , ** kwargs ):
171+ raise NotImplementedError ("""\
172+ `for_rail_string` is not implemented for NemoguardrailsGuard.
173+ We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
208174or the `from_pydantic` method.""" )
209175
210176 @deprecated (
211- "This method has been deprecated. Please use the main constructor `NemoGuardrailsGuard(nemorails=nemorails)` or the `from_pydantic` method." ,
177+ "Use `for_rail` instead. This method will be removed in 0.6.x." ,
178+ category = None ,
212179 )
180+ @classmethod
213181 def from_rail (cls , * args , ** kwargs ):
214182 raise NotImplementedError ("""\
215183 `from_rail` is not implemented for NemoguardrailsGuard.
216184We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
185+ or the `from_pydantic` method.""" )
186+
187+ @classmethod
188+ def for_rail (cls , * args , ** kwargs ):
189+ raise NotImplementedError ("""\
190+ `for_rail` is not implemented for NemoguardrailsGuard.
191+ We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)`
217192or the `from_pydantic` method.""" )
0 commit comments