1- from  typing  import  Any , Callable , Dict , Generic , Iterable , List , Optional , Union , cast 
1+ import  inspect 
2+ from  functools  import  partial 
3+ from  typing  import  (
4+     Any ,
5+     AsyncIterator ,
6+     Awaitable ,
7+     Callable ,
8+     Dict ,
9+     Generic ,
10+     Iterable ,
11+     List ,
12+     Optional ,
13+     Union ,
14+     cast ,
15+ )
216from  typing_extensions  import  deprecated 
317
418from  guardrails .classes .output_type  import  OT , OutputTypes 
519from  guardrails .classes .validation_outcome  import  ValidationOutcome 
620from  guardrails .classes .validation .validator_reference  import  ValidatorReference 
721
8- from  guardrails  import  Guard 
22+ from  guardrails  import  Guard ,  AsyncGuard 
923
1024from  guardrails .formatters .base_formatter  import  BaseFormatter 
1125from  guardrails .types .pydantic  import  ModelOrListOfModels 
2034        "`pip install nemoguardrails`." 
2135    )
2236
37+ try :
38+     import  nest_asyncio 
39+ 
40+     nest_asyncio .apply ()
41+     import  asyncio 
42+ except  ImportError :
43+     raise  ImportError (
44+         "Could not import nest_asyncio, please install it with " 
45+         "`pip install nest_asyncio`." 
46+     )
47+ 
2348
2449class  NemoguardrailsGuard (Guard , Generic [OT ]):
2550    def  __init__ (
@@ -30,6 +55,28 @@ def __init__(
3055    ):
3156        super ().__init__ (* args , ** kwargs )
3257        self ._nemorails  =  nemorails 
58+         self ._generate  =  self ._nemorails .generate 
59+ 
60+     def  _custom_nemo_callable (self , * args , generate_kwargs , ** kwargs ):
61+         # .generate doesn't like temp 
62+         kwargs .pop ("temperature" , None )
63+ 
64+         messages  =  kwargs .pop ("messages" , None )
65+ 
66+         if  messages  ==  [] or  messages  is  None :
67+             raise  ValueError ("messages must be passed during a call." )
68+ 
69+         if  not  generate_kwargs :
70+             generate_kwargs  =  {}
71+ 
72+         response  =  self ._generate (messages = messages , ** generate_kwargs )
73+ 
74+         if  inspect .iscoroutine (response ):
75+             response  =  asyncio .run (response )
76+ 
77+         return  response [  # type: ignore 
78+             "content" 
79+         ]
3380
3481    def  __call__ (
3582        self ,
@@ -59,12 +106,9 @@ def __call__(
59106 dictionaries, where each dictionary has a 'role' key and a 'content' key.""" 
60107            )
61108
62-         def  _custom_nemo_callable (* args , ** kwargs ):
63-             return  self ._custom_nemo_callable (
64-                 * args , generate_kwargs = generate_kwargs , ** kwargs 
65-             )
109+         llm_api  =  partial (self ._custom_nemo_callable , generate_kwargs = generate_kwargs )
66110
67-         return  super ().__call__ (llm_api = _custom_nemo_callable , * args , ** kwargs )
111+         return  super ().__call__ (llm_api = llm_api , * args , ** kwargs )
68112
69113    @classmethod  
70114    def  _init_guard_for_cls_method (
@@ -89,8 +133,8 @@ def _init_guard_for_cls_method(
89133    def  for_pydantic (
90134        cls ,
91135        output_class : ModelOrListOfModels ,
92-         nemorails : LLMRails ,
93136        * ,
137+         nemorails : LLMRails ,
94138        num_reasks : Optional [int ] =  None ,
95139        reask_messages : Optional [List [Dict ]] =  None ,
96140        messages : Optional [List [Dict ]] =  None ,
@@ -116,45 +160,6 @@ def for_pydantic(
116160        else :
117161            return  cast (NemoguardrailsGuard [Dict ], guard )
118162
119-     # create the callable 
120-     def  _custom_nemo_callable (self , * args , generate_kwargs , ** kwargs ):
121-         # .generate doesn't like temp 
122-         kwargs .pop ("temperature" , None )
123- 
124-         # msg_history, messages, prompt, and instruction all may or may not be present. 
125-         # if none of them are present, raise an error 
126-         # if messages is present, use that 
127-         # if msg_history is present, use 
128- 
129-         msg_history  =  kwargs .pop ("msg_history" , None )
130-         messages  =  kwargs .pop ("messages" , None )
131-         prompt  =  kwargs .pop ("prompt" , None )
132-         instructions  =  kwargs .pop ("instructions" , None )
133- 
134-         if  msg_history  is  not   None  and  messages  is  None :
135-             messages  =  msg_history 
136- 
137-         if  messages  is  None  and  msg_history  is  None :
138-             messages  =  []
139-             if  instructions  is  not   None :
140-                 messages .append ({"role" : "system" , "content" : instructions })
141-             if  prompt  is  not   None :
142-                 messages .append ({"role" : "system" , "content" : prompt })
143- 
144-         if  messages  ==  [] or  messages  is  None :
145-             raise  ValueError (
146-                 "messages, prompt, or instructions should be passed during a call." 
147-             )
148- 
149-         # kwargs["messages"] = messages 
150- 
151-         # return (self._nemorails.generate(**kwargs))["content"]  # type: ignore 
152-         if  not  generate_kwargs :
153-             generate_kwargs  =  {}
154-         return  (self ._nemorails .generate (messages = messages , ** generate_kwargs ))[  # type: ignore 
155-             "content" 
156-         ]
157- 
158163    @deprecated ( 
159164        "Use `for_rail_string` instead. This method will be removed in 0.6.x." , 
160165        category = None , 
@@ -190,3 +195,34 @@ def for_rail(cls, *args, **kwargs):
190195 `for_rail` is not implemented for NemoguardrailsGuard.
191196We recommend using the main constructor `NemoGuardrailsGuard(nemorails=nemorails)` 
192197or the `from_pydantic` method.""" )
198+ 
199+ 
200+ class  AsyncNemoguardrailsGuard (NemoguardrailsGuard , AsyncGuard , Generic [OT ]):
201+     def  __init__ (
202+         self ,
203+         nemorails : LLMRails ,
204+         * args ,
205+         ** kwargs ,
206+     ):
207+         super ().__init__ (nemorails , * args , ** kwargs )
208+         self ._generate  =  self ._nemorails .generate_async 
209+ 
210+     async  def  _custom_nemo_callable (self , * args , generate_kwargs , ** kwargs ):
211+         return  super ()._custom_nemo_callable (
212+             * args , generate_kwargs = generate_kwargs , ** kwargs 
213+         )
214+ 
215+     async  def  __call__ (  # type: ignore 
216+         self ,
217+         llm_api : Optional [Callable ] =  None ,
218+         generate_kwargs : Optional [Dict ] =  None ,
219+         * args ,
220+         ** kwargs ,
221+     ) ->  Union [
222+         ValidationOutcome [OT ],
223+         Awaitable [ValidationOutcome [OT ]],
224+         AsyncIterator [ValidationOutcome [OT ]],
225+     ]:
226+         return  await  super ().__call__ (
227+             llm_api = llm_api , generate_kwargs = generate_kwargs , * args , ** kwargs 
228+         )  # type: ignore 
0 commit comments