Skip to content

Commit bc7ff07

Browse files
committed
tests passing
1 parent 66a568b commit bc7ff07

25 files changed

+326
-212
lines changed

guardrails/classes/history/call.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from guardrails.classes.validation.validation_result import ValidationResult
1717
from guardrails.constants import error_status, fail_status, not_run_status, pass_status
1818
from guardrails.prompt.messages import Messages
19+
from guardrails.prompt import Prompt, Instructions
1920
from guardrails.classes.validation.validator_logs import ValidatorLogs
2021
from guardrails.actions.reask import (
2122
ReAsk,
@@ -93,10 +94,13 @@ def compiled_messages(self) -> Optional[str]:
9394
prompt_params = initial_inputs.prompt_params or {}
9495
compiled_messages = []
9596
for message in messages:
97+
content = message["content"].format(**prompt_params)
98+
if isinstance(content, (Prompt, Instructions)):
99+
content = content._source
96100
compiled_messages.append(
97101
{
98102
"role": message["role"],
99-
"content": message["content"].format(**prompt_params),
103+
"content": content,
100104
}
101105
)
102106

@@ -112,12 +116,28 @@ def reask_messages(self) -> Stack[Messages]:
112116
reasks = self.iterations.copy()
113117
initial_messages = reasks.first
114118
reasks.remove(initial_messages) # type: ignore
115-
return Stack(
116-
*[
117-
r.inputs.messages if r.inputs.messages is not None else None
118-
for r in reasks
119-
]
120-
)
119+
initial_inputs = self.iterations.first.inputs
120+
prompt_params = initial_inputs.prompt_params or {}
121+
compiled_reasks = []
122+
for reask in reasks:
123+
messages: Messages = reask.inputs.messages
124+
125+
if messages is None:
126+
compiled_reasks.append(None)
127+
else:
128+
compiled_messages = []
129+
for message in messages:
130+
content = message["content"].format(**prompt_params)
131+
if isinstance(content, (Prompt, Instructions)):
132+
content = content._source
133+
compiled_messages.append(
134+
{
135+
"role": message["role"],
136+
"content": content,
137+
}
138+
)
139+
compiled_reasks.append(compiled_messages)
140+
return Stack(*compiled_reasks)
121141

122142
return Stack()
123143

guardrails/classes/history/inputs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from guardrails.prompt.messages import Messages
1010
from guardrails.prompt.instructions import Instructions
1111

12+
1213
class Inputs(IInputs, ArbitraryModel):
1314
"""Inputs represent the input data that is passed into the validation loop.
1415
@@ -38,7 +39,9 @@ class Inputs(IInputs, ArbitraryModel):
3839
"provided by the user via Guard.parse.",
3940
default=None,
4041
)
41-
messages: Optional[Union[List[Dict[str, Union[str, Prompt, Instructions]]], Messages]] = Field(
42+
messages: Optional[
43+
Union[List[Dict[str, Union[str, Prompt, Instructions]]], Messages]
44+
] = Field(
4245
description="The message history provided by the user for chat model calls.",
4346
default=None,
4447
)
@@ -95,7 +98,6 @@ def to_dict(self) -> Dict[str, Any]:
9598
@classmethod
9699
def from_interface(cls, i_inputs: IInputs) -> "Inputs":
97100
deserialized_messages = None
98-
print("====== inputs.py: Inputs.from_interface() ======", i_inputs)
99101
if i_inputs.messages:
100102
deserialized_messages = []
101103
for msg in i_inputs.messages:

guardrails/guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def __call__(
841841
"""
842842

843843
messages = messages or self._exec_opts.messages or []
844-
print("==== messages is", messages)
844+
845845
# if messages is not None and not len(messages):
846846
# raise RuntimeError(
847847
# "You must provide messages. "

guardrails/llm_providers.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
cast,
1414
)
1515

16-
from guardrails.prompt import Prompt
16+
from guardrails.prompt import Prompt, Instructions
1717
from guardrails_api_client.models import LLMResource
1818

1919
from guardrails.errors import UserFacingException
@@ -25,6 +25,12 @@
2525

2626
from guardrails.types.inputs import MessageHistory
2727

28+
import warnings
29+
30+
from guardrails.utils.safe_get import safe_get
31+
from guardrails.telemetry import trace_llm_call, trace_operation
32+
33+
2834
# todo fix circular import
2935
def messages_string(messages: MessageHistory) -> str:
3036
messages_copy = ""
@@ -37,9 +43,6 @@ def messages_string(messages: MessageHistory) -> str:
3743
messages_copy += content
3844
return messages_copy
3945

40-
from guardrails.utils.safe_get import safe_get
41-
from guardrails.telemetry import trace_llm_call, trace_operation
42-
4346

4447
###
4548
# Synchronous wrappers
@@ -183,9 +186,7 @@ def _invoke_llm(
183186
"Install with `pip install litellm`"
184187
) from e
185188
if messages is not None:
186-
messages = litellm_messages(
187-
prompt=text, messages=messages
188-
)
189+
messages = litellm_messages(prompt=text, messages=messages)
189190
kwargs["messages"] = messages
190191

191192
trace_operation(
@@ -265,7 +266,11 @@ def _invoke_llm(
265266

266267
class HuggingFaceModelCallable(PromptCallableBase):
267268
def _invoke_llm(
268-
self, model_generate: Any, *args, messages: list[dict[str, str]], **kwargs
269+
self,
270+
model_generate: Any,
271+
*args,
272+
messages: list[dict[str, Union[str, Prompt, Instructions]]],
273+
**kwargs,
269274
) -> LLMResponse:
270275
try:
271276
import transformers # noqa: F401 # type: ignore
@@ -431,6 +436,14 @@ def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
431436
llm_api_args = inspect.getfullargspec(llm_api)
432437
if not llm_api_args.varkw:
433438
raise ValueError("Custom LLM callables must accept **kwargs!")
439+
if not llm_api_args.kwonlyargs or "messages" not in llm_api_args.kwonlyargs:
440+
warnings.warn(
441+
"We recommend including 'messages'"
442+
" as keyword-only arguments for custom LLM callables."
443+
" Doing so ensures these arguments are not unintentionally"
444+
" passed through to other calls via **kwargs.",
445+
UserWarning,
446+
)
434447
self.llm_api = llm_api
435448
super().__init__(*args, **kwargs)
436449

@@ -771,13 +784,16 @@ async def invoke_llm(
771784
class AsyncArbitraryCallable(AsyncPromptCallableBase):
772785
def __init__(self, llm_api: Callable, *args, **kwargs):
773786
llm_api_args = inspect.getfullargspec(llm_api)
774-
if not llm_api_args.args:
775-
raise ValueError(
776-
"Custom LLM callables must accept"
777-
" at least one positional argument for messages!"
778-
)
779787
if not llm_api_args.varkw:
780788
raise ValueError("Custom LLM callables must accept **kwargs!")
789+
if not llm_api_args.kwonlyargs or "messages" not in llm_api_args.kwonlyargs:
790+
warnings.warn(
791+
"We recommend including 'messages'"
792+
" as keyword-only arguments for custom LLM callables."
793+
" Doing so ensures these arguments are not unintentionally"
794+
" passed through to other calls via **kwargs.",
795+
UserWarning,
796+
)
781797
self.llm_api = llm_api
782798
super().__init__(*args, **kwargs)
783799

tests/integration_tests/mock_llm_outputs.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,26 +128,15 @@ def _invoke_llm(
128128
}
129129

130130
try:
131+
out_text = None
131132
if messages:
132133
if len(messages) == 2:
133134
key = (messages[0]["content"], messages[1]["content"])
134135
elif len(messages) == 1:
135136
key = (messages[0]["content"], None)
136-
out_text = mock_llm_responses[key]
137-
if prompt and instructions and not messages:
138-
out_text = mock_llm_responses[(prompt, instructions)]
139-
elif messages and not prompt and not instructions:
140-
if messages == entity_extraction.COMPILED_MSG_HISTORY:
141-
out_text = entity_extraction.LLM_OUTPUT
142-
elif (
143-
messages == string.MOVIE_MSG_HISTORY
144-
and base_model == pydantic.WITH_MSG_HISTORY
145-
):
146-
out_text = pydantic.MSG_HISTORY_LLM_OUTPUT_INCORRECT
147-
elif messages == string.MOVIE_MSG_HISTORY:
148-
out_text = string.MSG_LLM_OUTPUT_INCORRECT
149-
else:
150-
raise ValueError("messages not found")
137+
138+
if hasattr(mock_llm_responses[key], "read"):
139+
out_text = mock_llm_responses[key]
151140
else:
152141
raise ValueError(
153142
"specify either prompt and instructions " "or messages"
@@ -162,7 +151,8 @@ def _invoke_llm(
162151
print("\n prompt: \n", prompt)
163152
print("\n instructions: \n", instructions)
164153
print("\n messages: \n", messages)
165-
raise ValueError("Compiled prompt not found")
154+
print("\n base_model: \n", base_model)
155+
raise ValueError("Compiled prompt not found in mock llm response")
166156

167157

168158
class MockArbitraryCallable(ArbitraryCallable):

tests/integration_tests/test_assets/entity_extraction/compiled_prompt.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,3 @@ Given below is XML that describes the information to extract from this document
129129

130130

131131
ONLY return a valid JSON object (no other text is necessary). The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise.
132-
133-
134-
Json Output:
135-

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_full_reask.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,3 @@ Given below is XML that describes the information to extract from this document
110110
</output>
111111

112112
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
113-
114-
115-
Json Output:
116-

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_reask.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,3 @@ Given below is XML that describes the information to extract from this document
3939
</output>
4040

4141
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
42-
43-
44-
Json Output:
45-

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_1.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,3 @@ Given below is XML that describes the information to extract from this document
129129

130130

131131
ONLY return a valid JSON object (no other text is necessary). The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise.
132-
133-
134-
Json Output:
135-

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,3 @@ Here's an example of the structure:
9797
],
9898
"interest_rates": {}
9999
}
100-
101-
102-
Json Output:
103-

0 commit comments

Comments
 (0)