 from guardrails.schema import Schema, StringSchema
 from guardrails.utils.exception_utils import UserFacingException
 from guardrails.utils.llm_response import LLMResponse
+from guardrails.utils.openai_utils import OPENAI_VERSION
 from guardrails.utils.reask_utils import (
     NonParseableReAsk,
     ReAsk,
@@ -39,7 +40,6 @@ class Runner:
     Args:
         prompt: The prompt to use.
         api: The LLM API to call, which should return a string.
-        input_schema: The input schema to use for validation.
         output_schema: The output schema to use for validation.
         num_reasks: The maximum number of times to reask the LLM in case of
             validation failure, defaults to 0.
@@ -1120,16 +1120,28 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None):
             instructions=self.instructions,
             prompt=self.prompt,
             api=self.api,
-            input_schema=self.input_schema,
+            prompt_schema=self.prompt_schema,
+            instructions_schema=self.instructions_schema,
+            msg_history_schema=self.msg_history_schema,
             output_schema=self.output_schema,
             num_reasks=self.num_reasks,
             metadata=self.metadata,
         ):
-            instructions, prompt, msg_history, input_schema, output_schema = (
+            (
+                instructions,
+                prompt,
+                msg_history,
+                prompt_schema,
+                instructions_schema,
+                msg_history_schema,
+                output_schema,
+            ) = (
                 self.instructions,
                 self.prompt,
                 self.msg_history,
-                self.input_schema,
+                self.prompt_schema,
+                self.instructions_schema,
+                self.msg_history_schema,
                 self.output_schema,
             )

@@ -1140,7 +1152,9 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None):
                 prompt=prompt,
                 msg_history=msg_history,
                 prompt_params=prompt_params,
-                input_schema=input_schema,
+                prompt_schema=prompt_schema,
+                instructions_schema=instructions_schema,
+                msg_history_schema=msg_history_schema,
                 output_schema=output_schema,
                 output=self.output,
                 call_log=call_log,
@@ -1154,7 +1168,9 @@ def step(
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict]],
         prompt_params: Dict,
-        input_schema: Optional[Schema],
+        prompt_schema: Optional[StringSchema],
+        instructions_schema: Optional[StringSchema],
+        msg_history_schema: Optional[StringSchema],
         output_schema: Schema,
         call_log: Call,
         output: Optional[str] = None,
@@ -1181,7 +1197,9 @@ def step(
             instructions=instructions,
             prompt=prompt,
             prompt_params=prompt_params,
-            input_schema=input_schema,
+            prompt_schema=prompt_schema,
+            instructions_schema=instructions_schema,
+            msg_history_schema=msg_history_schema,
             output_schema=output_schema,
         ):
             # Prepare: run pre-processing, and input validation.
@@ -1191,13 +1209,16 @@ def step(
                 msg_history = None
             else:
                 instructions, prompt, msg_history = self.prepare(
+                    call_log,
                     index,
                     instructions,
                     prompt,
                     msg_history,
                     prompt_params,
                     api,
-                    input_schema,
+                    prompt_schema,
+                    instructions_schema,
+                    msg_history_schema,
                     output_schema,
                 )

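Note: the hunks above replace the single input_schema with three optional StringSchema inputs (prompt_schema, instructions_schema, msg_history_schema) and thread call_log into prepare, so each input channel is validated independently before the LLM call. A minimal standalone sketch of that per-channel pattern follows; the helper name and signatures here are illustrative assumptions, not the library's actual API.

# Illustrative sketch (hypothetical helper, not guardrails' API): validate
# each input channel with its own schema, skipping channels without one.
from typing import Callable, Dict, List, Optional

def validate_inputs(
    prompt: Optional[str],
    instructions: Optional[str],
    msg_history: Optional[List[Dict[str, str]]],
    prompt_schema: Optional[Callable[[str], None]] = None,
    instructions_schema: Optional[Callable[[str], None]] = None,
    msg_history_schema: Optional[Callable[[str], None]] = None,
) -> None:
    # Each schema raises on validation failure; absent schemas are skipped.
    if prompt_schema and prompt is not None:
        prompt_schema(prompt)
    if instructions_schema and instructions is not None:
        instructions_schema(instructions)
    if msg_history_schema and msg_history is not None:
        # A StringSchema validates text, so flatten the message contents.
        msg_history_schema("".join(m.get("content", "") for m in msg_history))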
@@ -1209,7 +1230,6 @@ def step(
             llm_response = self.call(
                 index, instructions, prompt, msg_history, api, output
             )
-            # iteration.outputs.llm_response_info = llm_response

             # Get the stream (generator) from the LLMResponse
             stream = llm_response.stream_output
@@ -1285,24 +1305,31 @@ def step(

     def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
         """Get the text from a chunk."""
+        chunk_text = ""
         if isinstance(api, OpenAICallable):
-            finished = chunk["choices"][0]["finish_reason"]
-            if finished:
-                chunk_text = ""
+            if OPENAI_VERSION.startswith("0"):
+                finished = chunk["choices"][0]["finish_reason"]
+                if "text" in chunk["choices"][0]:
+                    content = chunk["choices"][0]["text"]
+                    if not finished and content:
+                        chunk_text = content
             else:
-                if "text" not in chunk["choices"][0]:
-                    chunk_text = ""
-                else:
-                    chunk_text = chunk["choices"][0]["text"]
+                finished = chunk.choices[0].finish_reason
+                content = chunk.choices[0].text
+                if not finished and content:
+                    chunk_text = content
         elif isinstance(api, OpenAIChatCallable):
-            finished = chunk["choices"][0]["finish_reason"]
-            if finished:
-                chunk_text = ""
+            if OPENAI_VERSION.startswith("0"):
+                finished = chunk["choices"][0]["finish_reason"]
+                if "content" in chunk["choices"][0]["delta"]:
+                    content = chunk["choices"][0]["delta"]["content"]
+                    if not finished and content:
+                        chunk_text = content
             else:
-                if "content" not in chunk["choices"][0]["delta"]:
-                    chunk_text = ""
-                else:
-                    chunk_text = chunk["choices"][0]["delta"]["content"]
+                finished = chunk.choices[0].finish_reason
+                content = chunk.choices[0].delta.content
+                if not finished and content:
+                    chunk_text = content
         else:
             try:
                 chunk_text = chunk
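The rewritten get_chunk_text branches on OPENAI_VERSION: openai<1.0 streams dict-shaped chunks, while openai>=1.0 streams objects with attribute access, and in both cases a chunk contributes text only when it has content and no finish_reason. A runnable sketch of that split for the chat case, using SimpleNamespace to stand in for the v1 chunk object (shapes assumed from the diff itself, not from the openai library):

from types import SimpleNamespace

def text_from_chat_chunk(chunk, openai_v0: bool) -> str:
    # Mirror of the chat branch above: dict access for openai<1.0,
    # attribute access for openai>=1.0.
    if openai_v0:
        finished = chunk["choices"][0]["finish_reason"]
        content = chunk["choices"][0]["delta"].get("content")
    else:
        finished = chunk.choices[0].finish_reason
        content = chunk.choices[0].delta.content
    return content if (not finished and content) else ""

v0_chunk = {"choices": [{"finish_reason": None, "delta": {"content": "Hi"}}]}
v1_chunk = SimpleNamespace(
    choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content="Hi"))]
)
assert text_from_chat_chunk(v0_chunk, openai_v0=True) == "Hi"
assert text_from_chat_chunk(v1_chunk, openai_v0=False) == "Hi"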