
Commit f4107fc

Fix merge conflicts, fix bug for openai v1.x
1 parent 8b4c733 commit f4107fc

3 files changed (+58, −27 lines)

guardrails/guard.py

Lines changed: 6 additions & 2 deletions
@@ -394,7 +394,9 @@ def _call_sync(
                 prompt=prompt_obj,
                 msg_history=msg_history_obj,
                 api=get_llm_ask(llm_api, *args, **kwargs),
-                input_schema=self.input_schema,
+                prompt_schema=self.prompt_schema,
+                instructions_schema=self.instructions_schema,
+                msg_history_schema=self.msg_history_schema,
                 output_schema=self.output_schema,
                 num_reasks=num_reasks,
                 metadata=metadata,
@@ -410,7 +412,9 @@ def _call_sync(
                 prompt=prompt_obj,
                 msg_history=msg_history_obj,
                 api=get_llm_ask(llm_api, *args, **kwargs),
-                input_schema=self.input_schema,
+                prompt_schema=self.prompt_schema,
+                instructions_schema=self.instructions_schema,
+                msg_history_schema=self.msg_history_schema,
                 output_schema=self.output_schema,
                 num_reasks=num_reasks,
                 metadata=metadata,
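
Both call sites in _call_sync now pass one optional validator per input channel instead of the single combined input_schema. Below is a minimal, generic sketch of that pattern; it is illustrative only, not the guardrails implementation. The Validator alias and validate_inputs helper are stand-ins, while the prompt_schema / instructions_schema / msg_history_schema names come from the diff.

from typing import Callable, Dict, Optional

# Illustrative stand-in for a per-channel string validator: it returns the
# (possibly corrected) value or raises on validation failure.
Validator = Callable[[str], str]

def validate_inputs(
    prompt: Optional[str],
    instructions: Optional[str],
    msg_history_text: Optional[str],
    prompt_schema: Optional[Validator] = None,
    instructions_schema: Optional[Validator] = None,
    msg_history_schema: Optional[Validator] = None,
) -> Dict[str, Optional[str]]:
    """Run each input channel through its own validator, if one was supplied."""
    def run(value: Optional[str], schema: Optional[Validator]) -> Optional[str]:
        return schema(value) if (schema is not None and value is not None) else value

    return {
        "prompt": run(prompt, prompt_schema),
        "instructions": run(instructions, instructions_schema),
        "msg_history": run(msg_history_text, msg_history_schema),
    }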

guardrails/llm_providers.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def _invoke_llm(
                 "You must pass in either `text` or `msg_history` to `guard.__call__`."
             )
 
-        # Configure function calling if applicable
-        if base_model:
+        # Configure function calling if applicable (only for non-streaming)
+        if base_model and not kwargs.get("stream", False):
             function_params = [convert_pydantic_model_to_openai_fn(base_model)]
             if function_call is None:
                 function_call = {"name": function_params[0]["name"]}
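
This hunk makes function calling conditional on not streaming: when the caller passes stream=True, base_model is no longer converted into OpenAI function-calling parameters. A compressed sketch of the resulting behaviour follows, with a generic converter callable standing in for convert_pydantic_model_to_openai_fn so the snippet is self-contained; only the stream guard and the function_call default mirror the diff.

from typing import Any, Callable, Dict, List, Optional, Tuple

def build_function_kwargs(
    base_model: Optional[type],
    function_call: Optional[Dict[str, Any]],
    to_openai_fn: Callable[[type], Dict[str, Any]],
    **kwargs: Any,
) -> Tuple[Optional[List[Dict[str, Any]]], Optional[Dict[str, Any]]]:
    """Return (functions, function_call) kwargs, skipping both when streaming."""
    if base_model and not kwargs.get("stream", False):
        function_params = [to_openai_fn(base_model)]
        if function_call is None:
            # Default to calling the single generated function by name.
            function_call = {"name": function_params[0]["name"]}
        return function_params, function_call
    return None, None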

guardrails/run.py

Lines changed: 50 additions & 23 deletions
@@ -18,6 +18,7 @@
 from guardrails.schema import Schema, StringSchema
 from guardrails.utils.exception_utils import UserFacingException
 from guardrails.utils.llm_response import LLMResponse
+from guardrails.utils.openai_utils import OPENAI_VERSION
 from guardrails.utils.reask_utils import (
     NonParseableReAsk,
     ReAsk,
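
OPENAI_VERSION is what the streaming code further down uses to branch on the installed openai client's major version. Its definition lives in guardrails.utils.openai_utils and is not part of this commit; a plausible equivalent (an assumption, not the actual implementation) is simply the installed package's version string:

import openai

# Assumed stand-in for guardrails.utils.openai_utils.OPENAI_VERSION:
# the installed client's version string, e.g. "0.28.1" or "1.3.7".
OPENAI_VERSION: str = openai.__version__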
@@ -39,7 +40,6 @@ class Runner:
     Args:
         prompt: The prompt to use.
         api: The LLM API to call, which should return a string.
-        input_schema: The input schema to use for validation.
         output_schema: The output schema to use for validation.
         num_reasks: The maximum number of times to reask the LLM in case of
             validation failure, defaults to 0.
@@ -1120,16 +1120,28 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None):
             instructions=self.instructions,
             prompt=self.prompt,
             api=self.api,
-            input_schema=self.input_schema,
+            prompt_schema=self.prompt_schema,
+            instructions_schema=self.instructions_schema,
+            msg_history_schema=self.msg_history_schema,
             output_schema=self.output_schema,
             num_reasks=self.num_reasks,
             metadata=self.metadata,
         ):
-            instructions, prompt, msg_history, input_schema, output_schema = (
+            (
+                instructions,
+                prompt,
+                msg_history,
+                prompt_schema,
+                instructions_schema,
+                msg_history_schema,
+                output_schema,
+            ) = (
                 self.instructions,
                 self.prompt,
                 self.msg_history,
-                self.input_schema,
+                self.prompt_schema,
+                self.instructions_schema,
+                self.msg_history_schema,
                 self.output_schema,
             )
 
@@ -1140,7 +1152,9 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None):
                 prompt=prompt,
                 msg_history=msg_history,
                 prompt_params=prompt_params,
-                input_schema=input_schema,
+                prompt_schema=prompt_schema,
+                instructions_schema=instructions_schema,
+                msg_history_schema=msg_history_schema,
                 output_schema=output_schema,
                 output=self.output,
                 call_log=call_log,
@@ -1154,7 +1168,9 @@ def step(
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict]],
         prompt_params: Dict,
-        input_schema: Optional[Schema],
+        prompt_schema: Optional[StringSchema],
+        instructions_schema: Optional[StringSchema],
+        msg_history_schema: Optional[StringSchema],
         output_schema: Schema,
         call_log: Call,
         output: Optional[str] = None,
@@ -1181,7 +1197,9 @@ def step(
             instructions=instructions,
             prompt=prompt,
             prompt_params=prompt_params,
-            input_schema=input_schema,
+            prompt_schema=prompt_schema,
+            instructions_schema=instructions_schema,
+            msg_history_schema=msg_history_schema,
             output_schema=output_schema,
         ):
             # Prepare: run pre-processing, and input validation.
@@ -1191,13 +1209,16 @@ def step(
                 msg_history = None
             else:
                 instructions, prompt, msg_history = self.prepare(
+                    call_log,
                     index,
                     instructions,
                     prompt,
                     msg_history,
                     prompt_params,
                     api,
-                    input_schema,
+                    prompt_schema,
+                    instructions_schema,
+                    msg_history_schema,
                     output_schema,
                 )
 
@@ -1209,7 +1230,6 @@ def step(
             llm_response = self.call(
                 index, instructions, prompt, msg_history, api, output
             )
-            # iteration.outputs.llm_response_info = llm_response
 
             # Get the stream (generator) from the LLMResponse
             stream = llm_response.stream_output
@@ -1285,24 +1305,31 @@ def step(
 
     def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
         """Get the text from a chunk."""
+        chunk_text = ""
         if isinstance(api, OpenAICallable):
-            finished = chunk["choices"][0]["finish_reason"]
-            if finished:
-                chunk_text = ""
+            if OPENAI_VERSION.startswith("0"):
+                finished = chunk["choices"][0]["finish_reason"]
+                if "text" in chunk["choices"][0]:
+                    content = chunk["choices"][0]["text"]
+                    if not finished and content:
+                        chunk_text = content
             else:
-                if "text" not in chunk["choices"][0]:
-                    chunk_text = ""
-                else:
-                    chunk_text = chunk["choices"][0]["text"]
+                finished = chunk.choices[0].finish_reason
+                content = chunk.choices[0].text
+                if not finished and content:
+                    chunk_text = content
         elif isinstance(api, OpenAIChatCallable):
-            finished = chunk["choices"][0]["finish_reason"]
-            if finished:
-                chunk_text = ""
+            if OPENAI_VERSION.startswith("0"):
+                finished = chunk["choices"][0]["finish_reason"]
+                if "content" in chunk["choices"][0]["delta"]:
+                    content = chunk["choices"][0]["delta"]["content"]
+                    if not finished and content:
+                        chunk_text = content
             else:
-                if "content" not in chunk["choices"][0]["delta"]:
-                    chunk_text = ""
-                else:
-                    chunk_text = chunk["choices"][0]["delta"]["content"]
+                finished = chunk.choices[0].finish_reason
+                content = chunk.choices[0].delta.content
+                if not finished and content:
+                    chunk_text = content
         else:
             try:
                 chunk_text = chunk
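
The rewritten get_chunk_text is the openai v1.x fix named in the commit message: the 0.x client streams plain dicts, while the 1.x client streams typed objects whose fields are read with attribute access. Below is a standalone sketch of the same dispatch for chat streams, runnable against either client version; the function name and shape are illustrative, while the field access mirrors the diff.

import openai

def chat_chunk_text(chunk) -> str:
    """Extract the incremental text from one chat-completion stream chunk."""
    if openai.__version__.startswith("0"):
        # openai v0.x: chunks are plain dicts, and the first delta may
        # carry only a role, so "content" can be missing.
        choice = chunk["choices"][0]
        if choice["finish_reason"]:
            return ""
        return choice["delta"].get("content") or ""
    # openai v1.x: chunks are typed objects; delta.content can be None.
    choice = chunk.choices[0]
    if choice.finish_reason:
        return ""
    return choice.delta.content or ""

A caller would accumulate chat_chunk_text(chunk) over the chunks returned by a streaming chat-completions request; the empty-string return on the final (finish_reason) chunk matches the behaviour in the diff.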
