11import  copy 
22from  functools  import  partial 
3- from  typing  import  Any , Dict , List , Optional , Union ,  cast 
3+ from  typing  import  Any , Dict , List , Optional , cast 
44
55
66from  guardrails  import  validator_service 
1111from  guardrails .llm_providers  import  AsyncPromptCallableBase 
1212from  guardrails .logger  import  set_scope 
1313from  guardrails .run .runner  import  Runner 
14- from  guardrails .run .utils  import  messages_source ,  messages_string 
14+ from  guardrails .run .utils  import  messages_source 
1515from  guardrails .schema .validator  import  schema_validation 
1616from  guardrails .hub_telemetry .hub_tracing  import  async_trace 
1717from  guardrails .types .inputs  import  MessageHistory 
2525from  guardrails .constants  import  fail_status 
2626from  guardrails .prompt  import  Prompt 
2727
28+ 
2829class  AsyncRunner (Runner ):
2930    def  __init__ (
3031        self ,
@@ -331,9 +332,7 @@ async def prepare_messages(
331332            formatted_messages .append (msg_copy )
332333
333334        if  "messages"  in  self .validation_map :
334-             await  self .validate_messages (
335-                 call_log , formatted_messages , attempt_number 
336-             )
335+             await  self .validate_messages (call_log , formatted_messages , attempt_number )
337336
338337        return  formatted_messages 
339338
@@ -348,9 +347,11 @@ async def validate_messages(
348347                else  msg ["content" ]
349348            )
350349            inputs  =  Inputs (
351-                         llm_output = content ,
352-                     )
353-             iteration  =  Iteration (call_id = call_log .id , index = attempt_number , inputs = inputs )
350+                 llm_output = content ,
351+             )
352+             iteration  =  Iteration (
353+                 call_id = call_log .id , index = attempt_number , inputs = inputs 
354+             )
354355            call_log .iterations .insert (0 , iteration )
355356            value , _metadata  =  await  validator_service .async_validate (
356357                value = content ,
@@ -364,7 +365,7 @@ async def validate_messages(
364365            validated_msg  =  validator_service .post_process_validation (
365366                value , attempt_number , iteration , OutputTypes .STRING 
366367            )
367-              
368+ 
368369            iteration .outputs .validation_response  =  validated_msg 
369370
370371            if  isinstance (validated_msg , ReAsk ):
@@ -374,4 +375,4 @@ async def validate_messages(
374375
375376            msg ["content" ] =  cast (str , validated_msg )
376377
377-         return  messages   # type: ignore 
378+         return  messages   # type: ignore 
0 commit comments