
Commit 67feff4: more progress
1 parent 816ef5f

21 files changed, +104 -109 lines

guardrails/actions/reask.py

Lines changed: 1 addition & 4 deletions
@@ -284,10 +284,7 @@ def get_reask_setup_for_string(
     )
 
     instructions = None
-    if exec_options.reask_instructions:
-        instructions = Instructions(exec_options.reask_instructions)
-    if instructions is None:
-        instructions = Instructions("You are a helpful assistant.")
+    instructions = Instructions("You are a helpful assistant.")
     instructions = instructions.format(
         output_schema=schema_prompt_content,
         xml_output_schema=xml_output_schema,

guardrails/guard.py

Lines changed: 1 addition & 0 deletions
@@ -841,6 +841,7 @@ def __call__(
         """
 
         messages = messages or self._exec_opts.messages or []
+        print("==== messages is", messages)
        # if messages is not None and not len(messages):
        #     raise RuntimeError(
        #         "You must provide messages. "

guardrails/llm_providers.py

Lines changed: 10 additions & 16 deletions
@@ -416,16 +416,13 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMResponse:
 class ArbitraryCallable(PromptCallableBase):
     def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
         llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for messages!"
+            )
         if not llm_api_args.varkw:
             raise ValueError("Custom LLM callables must accept **kwargs!")
-        if not llm_api_args.kwonlyargs or "messages" not in llm_api_args.kwonlyargs:
-            warnings.warn(
-                "We recommend including 'messages'"
-                " as keyword-only arguments for custom LLM callables."
-                " Doing so ensures these arguments are not unintentionally"
-                " passed through to other calls via **kwargs.",
-                UserWarning,
-            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)

@@ -766,16 +763,13 @@ async def invoke_llm(
 class AsyncArbitraryCallable(AsyncPromptCallableBase):
     def __init__(self, llm_api: Callable, *args, **kwargs):
         llm_api_args = inspect.getfullargspec(llm_api)
+        if not llm_api_args.args:
+            raise ValueError(
+                "Custom LLM callables must accept"
+                " at least one positional argument for messages!"
+            )
         if not llm_api_args.varkw:
             raise ValueError("Custom LLM callables must accept **kwargs!")
-        if not llm_api_args.kwonlyargs or "messages" not in llm_api_args.kwonlyargs:
-            warnings.warn(
-                "We recommend including 'messages'"
-                " as keyword-only arguments for custom LLM callables."
-                " Doing so ensures these arguments are not unintentionally"
-                " passed through to other calls via **kwargs.",
-                UserWarning,
-            )
         self.llm_api = llm_api
         super().__init__(*args, **kwargs)
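
These signature checks only look at the callable's argument spec. A minimal sketch (not part of this commit) of a custom LLM callable that satisfies both new requirements, taking the messages list as its first positional argument and accepting **kwargs:

    def my_llm(messages, *args, **kwargs) -> str:
        # `messages` is the chat history; extra Guardrails arguments
        # (model, temperature, etc.) arrive through **kwargs.
        return "stubbed response"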

guardrails/prompt/messages.py

Lines changed: 6 additions & 2 deletions
@@ -62,12 +62,16 @@ def format(
         """Format the messages using the given keyword arguments."""
         formatted_messages = []
         for message in self.source:
+            if isinstance(message["content"], str):
+                msg_str = message["content"]
+            else:
+                msg_str = message["content"]._source
             # Only use the keyword arguments that are present in the message.
-            vars = get_template_variables(message["content"])
+            vars = get_template_variables(msg_str)
             filtered_kwargs = {k: v for k, v in kwargs.items() if k in vars}
 
             # Return another instance of the class with the formatted message.
-            formatted_message = Template(message["content"]).safe_substitute(
+            formatted_message = Template(msg_str).safe_substitute(
                 **filtered_kwargs
             )
             formatted_messages.append(
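
The template substitution above relies on string.Template.safe_substitute, which fills the variables it is given and leaves unknown ${...} placeholders in place instead of raising. A standalone sketch of that behavior (not from this commit):

    from string import Template

    # Only "doc" is supplied, so "${fmt}" survives untouched.
    Template("Summarize ${doc} as ${fmt}").safe_substitute(doc="the text")
    # -> 'Summarize the text as ${fmt}'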

guardrails/run/async_stream_runner.py

Lines changed: 3 additions & 2 deletions
@@ -200,15 +200,16 @@ async def async_step(
     def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
         """Get the text from a chunk."""
         chunk_text = ""
+
         try:
             finished = chunk.choices[0].finish_reason
-            content = chunk.choices[0].text
+            content = chunk.choices[0].delta.content
             if not finished and content:
                 chunk_text = content
         except Exception:
             try:
                 finished = chunk.choices[0].finish_reason
-                content = chunk.choices[0].delta.content
+                content = chunk.choices[0].text
                 if not finished and content:
                     chunk_text = content
             except Exception:

guardrails/run/runner.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@
 from guardrails.actions.reask import NonParseableReAsk, ReAsk, introspect
 from guardrails.telemetry import trace_call, trace_step
 
-
 class Runner:
     """Runner class that calls an LLM API with a prompt, and performs input and
     output validation.

guardrails/run/stream_runner.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 from guardrails.classes.output_type import OT, OutputTypes
 from guardrails.classes.validation_outcome import ValidationOutcome
 from guardrails.llm_providers import (
+    LiteLLMCallable,
     PromptCallableBase,
 )
 from guardrails.run.runner import Runner

@@ -250,13 +251,13 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
         chunk_text = ""
         try:
             finished = chunk.choices[0].finish_reason
-            content = chunk.choices[0].text
+            content = chunk.choices[0].delta.content
             if not finished and content:
                 chunk_text = content
         except Exception:
             try:
                 finished = chunk.choices[0].finish_reason
-                content = chunk.choices[0].delta.content
+                content = chunk.choices[0].text
                 if not finished and content:
                     chunk_text = content
             except Exception:
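
Both get_chunk_text implementations (async and sync) now try the chat-completions chunk shape first and fall back to the legacy completions shape. A rough sketch of that precedence, assuming OpenAI-style chunk objects as in the code above:

    def extract_chunk_text(chunk) -> str:
        # Chat streaming chunks carry text in choices[0].delta.content;
        # plain completion chunks carry it in choices[0].text.
        choice = chunk.choices[0]
        if choice.finish_reason:
            return ""
        delta = getattr(choice, "delta", None)
        if delta is not None and getattr(delta, "content", None):
            return delta.content
        return getattr(choice, "text", None) or ""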

tests/integration_tests/mock_llm_outputs.py

Lines changed: 11 additions & 13 deletions
@@ -66,7 +66,7 @@ def _invoke_llm(
         self,
         prompt=None,
         instructions=None,
-        msg_history=None,
+        messages=None,
         base_model=None,
         *args,
         **kwargs,

@@ -128,28 +128,26 @@ def _invoke_llm(
         }
 
         try:
-            if msg_history:
-                key = (msg_history[0]["content"], msg_history[1]["content"])
-                print("=========trying key", key)
+            if messages:
+                key = (messages[0]["content"], messages[1]["content"])
                 out_text = mock_llm_responses[key]
-                print("========found out text", out_text)
-            if prompt and instructions and not msg_history:
+            if prompt and instructions and not messages:
                 out_text = mock_llm_responses[(prompt, instructions)]
-            elif msg_history and not prompt and not instructions:
-                if msg_history == entity_extraction.COMPILED_MSG_HISTORY:
+            elif messages and not prompt and not instructions:
+                if messages == entity_extraction.COMPILED_MSG_HISTORY:
                     out_text = entity_extraction.LLM_OUTPUT
                 elif (
-                    msg_history == string.MOVIE_MSG_HISTORY
+                    messages == string.MOVIE_MSG_HISTORY
                     and base_model == pydantic.WITH_MSG_HISTORY
                 ):
                     out_text = pydantic.MSG_HISTORY_LLM_OUTPUT_INCORRECT
-                elif msg_history == string.MOVIE_MSG_HISTORY:
+                elif messages == string.MOVIE_MSG_HISTORY:
                     out_text = string.MSG_LLM_OUTPUT_INCORRECT
                 else:
-                    raise ValueError("msg_history not found")
+                    raise ValueError("messages not found")
             else:
                 raise ValueError(
-                    "specify either prompt and instructions " "or msg_history"
+                    "specify either prompt and instructions " "or messages"
                 )
             return LLMResponse(
                 output=out_text,

@@ -160,7 +158,7 @@ def _invoke_llm(
             print("Unrecognized prompt!")
             print("\n prompt: \n", prompt)
             print("\n instructions: \n", instructions)
-            print("\n msg_history: \n", msg_history)
+            print("\n messages: \n", messages)
             raise ValueError("Compiled prompt not found")
 
 
tests/integration_tests/test_assets/custom_llm.py

Lines changed: 2 additions & 2 deletions
@@ -2,16 +2,16 @@
 
 
 def mock_llm(
+    messages,
     *args,
-    messages: Optional[List[Dict[str, str]]] = None,
     **kwargs,
 ) -> str:
     return ""
 
 
 async def mock_async_llm(
+    messages,
     *args,
-    messages: Optional[List[Dict[str, str]]] = None,
     **kwargs,
 ) -> str:
     return ""
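
With messages moved into the first positional slot, these fixtures pass the new checks added in guardrails/llm_providers.py: inspect.getfullargspec now reports a positional argument and a **kwargs catch-all. A quick sketch of that check applied to the fixture above:

    import inspect

    spec = inspect.getfullargspec(mock_llm)
    assert spec.args == ["messages"]  # at least one positional arg
    assert spec.varkw == "kwargs"     # **kwargs is accepted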

tests/integration_tests/test_assets/string/compiled_prompt.txt

Lines changed: 0 additions & 4 deletions
@@ -2,7 +2,3 @@
 Given the following ingredients, what would you call this pizza?
 
 tomato, cheese, sour cream
-
-
-String Output:
-