Skip to content

Commit b0baed1

Browse files
committed
reapply litellm updates to support only messages llm kwarg
1 parent eb212ba commit b0baed1

File tree

17 files changed

+259
-1403
lines changed

17 files changed

+259
-1403
lines changed

guardrails/actions/reask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def get_reask_setup(
499499
use_full_schema: Optional[bool] = False,
500500
prompt_params: Optional[Dict[str, Any]] = None,
501501
exec_options: Optional[GuardExecutionOptions] = None,
502-
) -> Tuple[Dict[str, Any], Prompt, Instructions]:
502+
) -> Tuple[Dict[str, Any], Messages]:
503503
prompt_params = prompt_params or {}
504504
exec_options = exec_options or GuardExecutionOptions()
505505

guardrails/async_guard.py

Lines changed: 31 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,17 @@ def from_pydantic(
9292
cls,
9393
output_class: ModelOrListOfModels,
9494
*,
95-
prompt: Optional[str] = None,
96-
instructions: Optional[str] = None,
95+
messages: Optional[List[Dict]] = None,
9796
num_reasks: Optional[int] = None,
98-
reask_prompt: Optional[str] = None,
99-
reask_instructions: Optional[str] = None,
10097
reask_messages: Optional[List[Dict]] = None,
10198
tracer: Optional[Tracer] = None,
10299
name: Optional[str] = None,
103100
description: Optional[str] = None,
104101
):
105102
guard = super().from_pydantic(
106103
output_class,
107-
prompt=prompt,
108-
instructions=instructions,
109104
num_reasks=num_reasks,
110-
reask_prompt=reask_prompt,
111-
reask_instructions=reask_instructions,
105+
messages=messages,
112106
reask_messages=reask_messages,
113107
tracer=tracer,
114108
name=name,
@@ -125,10 +119,8 @@ def from_string(
125119
validators: Sequence[Validator],
126120
*,
127121
string_description: Optional[str] = None,
128-
prompt: Optional[str] = None,
129-
instructions: Optional[str] = None,
130-
reask_prompt: Optional[str] = None,
131-
reask_instructions: Optional[str] = None,
122+
messages: Optional[List[Dict]] = None,
123+
reask_messages: Optional[List[Dict]] = None,
132124
num_reasks: Optional[int] = None,
133125
tracer: Optional[Tracer] = None,
134126
name: Optional[str] = None,
@@ -137,10 +129,8 @@ def from_string(
137129
guard = super().from_string(
138130
validators,
139131
string_description=string_description,
140-
prompt=prompt,
141-
instructions=instructions,
142-
reask_prompt=reask_prompt,
143-
reask_instructions=reask_instructions,
132+
messages=messages,
133+
reask_messages=reask_messages,
144134
num_reasks=num_reasks,
145135
tracer=tracer,
146136
name=name,
@@ -178,9 +168,7 @@ async def _execute(
178168
llm_output: Optional[str] = None,
179169
prompt_params: Optional[Dict] = None,
180170
num_reasks: Optional[int] = None,
181-
prompt: Optional[str] = None,
182-
instructions: Optional[str] = None,
183-
msg_history: Optional[List[Dict]] = None,
171+
messages: Optional[List[Dict]] = None,
184172
metadata: Optional[Dict],
185173
full_schema_reask: Optional[bool] = None,
186174
**kwargs,
@@ -192,10 +180,8 @@ async def _execute(
192180
self._fill_validator_map()
193181
self._fill_validators()
194182
metadata = metadata or {}
195-
if not llm_output and llm_api and not (prompt or msg_history):
196-
raise RuntimeError(
197-
"'prompt' or 'msg_history' must be provided in order to call an LLM!"
198-
)
183+
if not llm_output and llm_api and not (messages):
184+
raise RuntimeError("'messages' must be provided in order to call an LLM!")
199185
# check if validator requirements are fulfilled
200186
missing_keys = verify_metadata_requirements(metadata, self._validators)
201187
if missing_keys:
@@ -210,9 +196,7 @@ async def __exec(
210196
llm_output: Optional[str] = None,
211197
prompt_params: Optional[Dict] = None,
212198
num_reasks: Optional[int] = None,
213-
prompt: Optional[str] = None,
214-
instructions: Optional[str] = None,
215-
msg_history: Optional[List[Dict]] = None,
199+
messages: Optional[List[Dict]] = None,
216200
metadata: Optional[Dict] = None,
217201
full_schema_reask: Optional[bool] = None,
218202
**kwargs,
@@ -245,14 +229,6 @@ async def __exec(
245229
("guard_id", self.id),
246230
("user_id", self._user_id),
247231
("llm_api", llm_api_str),
248-
(
249-
"custom_reask_prompt",
250-
self._exec_opts.reask_prompt is not None,
251-
),
252-
(
253-
"custom_reask_instructions",
254-
self._exec_opts.reask_instructions is not None,
255-
),
256232
(
257233
"custom_reask_messages",
258234
self._exec_opts.reask_messages is not None,
@@ -273,13 +249,10 @@ async def __exec(
273249
"This should never happen."
274250
)
275251

276-
input_prompt = prompt or self._exec_opts.prompt
277-
input_instructions = instructions or self._exec_opts.instructions
252+
messages = messages or self._exec_opts.messages
278253
call_inputs = CallInputs(
279254
llm_api=llm_api,
280-
prompt=input_prompt,
281-
instructions=input_instructions,
282-
msg_history=msg_history,
255+
messages=messages,
283256
prompt_params=prompt_params,
284257
num_reasks=self._num_reasks,
285258
metadata=metadata,
@@ -298,9 +271,7 @@ async def __exec(
298271
prompt_params=prompt_params,
299272
metadata=metadata,
300273
full_schema_reask=full_schema_reask,
301-
prompt=prompt,
302-
instructions=instructions,
303-
msg_history=msg_history,
274+
messages=messages,
304275
*args,
305276
**kwargs,
306277
)
@@ -315,9 +286,7 @@ async def __exec(
315286
llm_output=llm_output,
316287
prompt_params=prompt_params,
317288
num_reasks=self._num_reasks,
318-
prompt=prompt,
319-
instructions=instructions,
320-
msg_history=msg_history,
289+
messages=messages,
321290
metadata=metadata,
322291
full_schema_reask=full_schema_reask,
323292
call_log=call_log,
@@ -343,9 +312,7 @@ async def __exec(
343312
llm_output=llm_output,
344313
prompt_params=prompt_params,
345314
num_reasks=num_reasks,
346-
prompt=prompt,
347-
instructions=instructions,
348-
msg_history=msg_history,
315+
messages=messages,
349316
metadata=metadata,
350317
full_schema_reask=full_schema_reask,
351318
*args,
@@ -362,9 +329,7 @@ async def _exec(
362329
num_reasks: int = 0, # Should be defined at this point
363330
metadata: Dict, # Should be defined at this point
364331
full_schema_reask: bool = False, # Should be defined at this point
365-
prompt: Optional[str],
366-
instructions: Optional[str],
367-
msg_history: Optional[List[Dict]],
332+
messages: Optional[List[Dict]],
368333
**kwargs,
369334
) -> Union[
370335
ValidationOutcome[OT],
@@ -377,9 +342,7 @@ async def _exec(
377342
llm_api: The LLM API to call asynchronously (e.g. openai.Completion.acreate)
378343
prompt_params: The parameters to pass to the prompt.format() method.
379344
num_reasks: The max times to re-ask the LLM for invalid output.
380-
prompt: The prompt to use for the LLM.
381-
instructions: Instructions for chat models.
382-
msg_history: The message history to pass to the LLM.
345+
messages: The message history to pass to the LLM.
383346
metadata: Metadata to pass to the validators.
384347
full_schema_reask: When reasking, whether to regenerate the full schema
385348
or just the incorrect values.
@@ -396,9 +359,7 @@ async def _exec(
396359
output_schema=self.output_schema.to_dict(),
397360
num_reasks=num_reasks,
398361
validation_map=self._validator_map,
399-
prompt=prompt,
400-
instructions=instructions,
401-
msg_history=msg_history,
362+
messages=messages,
402363
api=api,
403364
metadata=metadata,
404365
output=llm_output,
@@ -418,9 +379,7 @@ async def _exec(
418379
output_schema=self.output_schema.to_dict(),
419380
num_reasks=num_reasks,
420381
validation_map=self._validator_map,
421-
prompt=prompt,
422-
instructions=instructions,
423-
msg_history=msg_history,
382+
messages=messages,
424383
api=api,
425384
metadata=metadata,
426385
output=llm_output,
@@ -441,9 +400,7 @@ async def __call__(
441400
*args,
442401
prompt_params: Optional[Dict] = None,
443402
num_reasks: Optional[int] = 1,
444-
prompt: Optional[str] = None,
445-
instructions: Optional[str] = None,
446-
msg_history: Optional[List[Dict]] = None,
403+
messages: Optional[List[Dict]] = None,
447404
metadata: Optional[Dict] = None,
448405
full_schema_reask: Optional[bool] = None,
449406
**kwargs,
@@ -460,9 +417,7 @@ async def __call__(
460417
(e.g. openai.completions.create or openai.chat.completions.create)
461418
prompt_params: The parameters to pass to the prompt.format() method.
462419
num_reasks: The max times to re-ask the LLM for invalid output.
463-
prompt: The prompt to use for the LLM.
464-
instructions: Instructions for chat models.
465-
msg_history: The message history to pass to the LLM.
420+
messages: The message history to pass to the LLM.
466421
metadata: Metadata to pass to the validators.
467422
full_schema_reask: When reasking, whether to regenerate the full schema
468423
or just the incorrect values.
@@ -473,16 +428,13 @@ async def __call__(
473428
The raw text output from the LLM and the validated output.
474429
"""
475430

476-
instructions = instructions or self._exec_opts.instructions
477-
prompt = prompt or self._exec_opts.prompt
478-
msg_history = msg_history or kwargs.pop("messages", None) or []
431+
messages = messages or kwargs.pop("messages", None) or []
479432

480-
if prompt is None:
481-
if msg_history is not None and not len(msg_history):
482-
raise RuntimeError(
483-
"You must provide a prompt if msg_history is empty. "
484-
"Alternatively, you can provide a prompt in the Schema constructor."
485-
)
433+
if messages is not None and not len(messages):
434+
raise RuntimeError(
435+
"You must provide a prompt if messages is empty. "
436+
"Alternatively, you can provide a prompt in the Schema constructor."
437+
)
486438

487439
return await trace_async_guard_execution(
488440
self.name,
@@ -493,9 +445,7 @@ async def __call__(
493445
llm_api=llm_api,
494446
prompt_params=prompt_params,
495447
num_reasks=num_reasks,
496-
prompt=prompt,
497-
instructions=instructions,
498-
msg_history=msg_history,
448+
messages=messages,
499449
metadata=metadata,
500450
full_schema_reask=full_schema_reask,
501451
**kwargs,
@@ -538,14 +488,8 @@ async def parse(
538488
if llm_api is None
539489
else 1
540490
)
541-
default_prompt = self._exec_opts.prompt if llm_api is not None else None
542-
prompt = kwargs.pop("prompt", default_prompt)
543-
544-
default_instructions = self._exec_opts.instructions if llm_api else None
545-
instructions = kwargs.pop("instructions", default_instructions)
546-
547-
default_msg_history = self._exec_opts.msg_history if llm_api else None
548-
msg_history = kwargs.pop("msg_history", default_msg_history)
491+
default_messages = self._exec_opts.messages if llm_api else None
492+
messages = kwargs.pop("messages", default_messages)
549493

550494
return await trace_async_guard_execution( # type: ignore
551495
self.name,
@@ -557,9 +501,7 @@ async def parse(
557501
llm_api=llm_api,
558502
prompt_params=prompt_params,
559503
num_reasks=final_num_reasks,
560-
prompt=prompt,
561-
instructions=instructions,
562-
msg_history=msg_history,
504+
messages=messages,
563505
metadata=metadata,
564506
full_schema_reask=full_schema_reask,
565507
**kwargs,

guardrails/classes/execution/guard_execution_options.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
@dataclass
66
class GuardExecutionOptions:
7-
prompt: Optional[str] = None
8-
instructions: Optional[str] = None
9-
msg_history: Optional[List[Dict]] = None
107
messages: Optional[List[Dict]] = None
11-
reask_prompt: Optional[str] = None
12-
reask_instructions: Optional[str] = None
138
reask_messages: Optional[List[Dict]] = None
149
num_reasks: Optional[int] = None

0 commit comments

Comments
 (0)