
Commit e684739

more progress
1 parent b97caa2 commit e684739

11 files changed, +150 -153 lines changed

guardrails/formatters/json_formatter.py

Lines changed: 10 additions & 2 deletions
@@ -105,12 +105,16 @@ def fn(
         messages: Optional[List[Dict[str, str]]] = None,
         **kwargs,
     ) -> str:
+        prompt = ""
+        for msg in messages:
+            prompt += msg["content"]
+
         return json.dumps(
             Jsonformer(
                 model=model.model,
                 tokenizer=model.tokenizer,
                 json_schema=self.output_schema,
-                messages=messages,
+                prompt=prompt
             )()
         )

@@ -127,12 +131,16 @@ def fn(
         messages: Optional[List[Dict[str, str]]] = None,
         **kwargs,
     ) -> str:
+        prompt = ""
+        for msg in messages:
+            prompt += msg["content"]
+
         return json.dumps(
             Jsonformer(
                 model=model,
                 tokenizer=tokenizer,
                 json_schema=self.output_schema,
-                messages=messages,
+                prompt=prompt
             )()
         )

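The formatter change flattens the chat-style messages list into the single prompt string handed to Jsonformer, replacing the earlier `messages=` keyword with `prompt=`. A minimal sketch of that flattening step, using a made-up message list (the content below is illustrative only):

    # Illustrative input; in the formatter this is the `messages` argument.
    messages = [
        {"role": "system", "content": "Return JSON describing a pet. "},
        {"role": "user", "content": "The pet is a tabby cat."},
    ]

    # Same concatenation the diff adds: contents are joined in order, with no
    # separator, and the result becomes Jsonformer's `prompt` argument.
    prompt = ""
    for msg in messages:
        prompt += msg["content"]

    assert prompt == "Return JSON describing a pet. The pet is a tabby cat."
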
guardrails/guard.py

Lines changed: 7 additions & 7 deletions
@@ -647,8 +647,8 @@ def _execute(
             reask_messages=reask_messages,
         )
         metadata = metadata or {}
-        if not llm_output and llm_api and not (messages):
-            raise RuntimeError("'messages' must be provided in order to call an LLM!")
+        # if not llm_output and llm_api and not (messages):
+        #     raise RuntimeError("'messages' must be provided in order to call an LLM!")

         # check if validator requirements are fulfilled
         missing_keys = verify_metadata_requirements(metadata, self._validators)

@@ -841,11 +841,11 @@ def __call__(
         """

         messages = messages or self._exec_opts.messages or []
-        if messages is not None and not len(messages):
-            raise RuntimeError(
-                "You must provide messages. "
-                "Alternatively, you can provide a prompt in the Schema constructor."
-            )
+        # if messages is not None and not len(messages):
+        #     raise RuntimeError(
+        #         "You must provide messages. "
+        #         "Alternatively, you can provide a prompt in the Schema constructor."
+        #     )
         return trace_guard_execution(
             self.name,
             self.history,

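With both empty-messages checks commented out, execution no longer fails fast when no messages are passed at call time; `__call__` still falls back to `self._exec_opts.messages`. A rough sketch of the pattern this tolerates, assuming `from_pydantic` accepts a `messages` argument the way the updated async test does (the schema and model name below are placeholders, not from the diff):

    from typing import List
    from pydantic import BaseModel
    from guardrails import Guard

    class Pet(BaseModel):
        name: str
        toys: List[str]

    # Messages attached to the guard itself; calling without per-call messages
    # now relies on this fallback instead of raising RuntimeError.
    guard = Guard.from_pydantic(
        Pet,
        messages=[{"role": "user", "content": "Describe a pet as JSON."}],
    )
    # result = guard(model="gpt-3.5-turbo")  # hypothetical call, not from the diff
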
guardrails/llm_providers.py

Lines changed: 14 additions & 15 deletions
@@ -43,14 +43,14 @@ def nonchat_prompt(prompt: str, instructions: Optional[str] = None) -> str:
 def chat_prompt(
     prompt: Optional[str],
     instructions: Optional[str] = None,
-    msg_history: Optional[List[Dict]] = None,
+    messages: Optional[List[Dict]] = None,
 ) -> List[Dict[str, str]]:
     """Prepare final prompt for chat engine."""
-    if msg_history:
-        return msg_history
+    if messages:
+        return messages
     if prompt is None:
         raise PromptCallableException(
-            "You must pass in either `text` or `msg_history` to `guard.__call__`."
+            "You must pass in either `text` or `messages` to `guard.__call__`."
         )

     if not instructions:

@@ -65,14 +65,14 @@ def chat_prompt(
 def litellm_messages(
     prompt: Optional[str],
     instructions: Optional[str] = None,
-    msg_history: Optional[List[Dict]] = None,
+    messages: Optional[List[Dict]] = None,
 ) -> List[Dict[str, str]]:
     """Prepare messages for LiteLLM."""
-    if msg_history:
-        return msg_history
+    if messages:
+        return messages
     if prompt is None:
         raise PromptCallableException(
-            "Either `text` or `msg_history` required for `guard.__call__`."
+            "Either `text` or `messages` required for `guard.__call__`."
         )

     if instructions:

@@ -143,8 +143,7 @@ def _invoke_llm(
         self,
         text: Optional[str] = None,
         model: str = "gpt-3.5-turbo",
-        instructions: Optional[str] = None,
-        msg_history: Optional[List[Dict]] = None,
+        messages: Optional[List[Dict]] = None,
         *args,
         **kwargs,
     ) -> LLMResponse:

@@ -170,9 +169,9 @@ def _invoke_llm(
                 "The `litellm` package is not installed. "
                 "Install with `pip install litellm`"
             ) from e
-        if text is not None or instructions is not None or msg_history is not None:
+        if messages is not None:
             messages = litellm_messages(
-                prompt=text, instructions=instructions, msg_history=msg_history
+                prompt=text, messages=messages
             )
             kwargs["messages"] = messages

@@ -592,7 +591,7 @@ async def invoke_llm(
         self,
         text: Optional[str] = None,
         instructions: Optional[str] = None,
-        msg_history: Optional[List[Dict]] = None,
+        messages: Optional[List[Dict]] = None,
         *args,
         **kwargs,
     ):

@@ -619,11 +618,11 @@ async def invoke_llm(
                 "Install with `pip install litellm`"
             ) from e

-        if text is not None or instructions is not None or msg_history is not None:
+        if text is not None or instructions is not None or messages is not None:
             messages = litellm_messages(
                 prompt=text,
                 instructions=instructions,
-                msg_history=msg_history,
+                messages=messages,
             )
             kwargs["messages"] = messages

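Both helpers keep their short-circuit shape under the new name: a supplied messages list is returned unchanged, only a bare prompt (plus optional instructions) gets wrapped, and with neither they raise `PromptCallableException`. A hedged usage sketch; the exact wrapping of prompt and instructions happens below the lines shown in this hunk, so the second result is assumed rather than quoted:

    from guardrails.llm_providers import litellm_messages

    # A messages list passes straight through.
    msgs = litellm_messages(prompt=None, messages=[{"role": "user", "content": "Hi"}])

    # A bare prompt is wrapped into a message list (shape assumed from context).
    msgs = litellm_messages(prompt="Summarize the report.", instructions="Be brief.")

    # litellm_messages(prompt=None) would raise PromptCallableException:
    # "Either `text` or `messages` required for `guard.__call__`."
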
guardrails/run/async_runner.py

Lines changed: 34 additions & 24 deletions
@@ -22,6 +22,8 @@
 from guardrails.actions.reask import NonParseableReAsk, ReAsk
 from guardrails.telemetry import trace_async_call, trace_async_step

+from guardrails.constants import fail_status
+from guardrails.prompt import Prompt

 class AsyncRunner(Runner):
     def __init__(

@@ -339,29 +341,37 @@ async def prepare_messages(
     async def validate_messages(
         self, call_log: Call, messages: MessageHistory, attempt_number: int
     ):
-        msg_str = messages_string(messages)
-        inputs = Inputs(
-            llm_output=msg_str,
-        )
-        iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
-        call_log.iterations.insert(0, iteration)
-        value, _metadata = await validator_service.async_validate(
-            value=msg_str,
-            metadata=self.metadata,
-            validator_map=self.validation_map,
-            iteration=iteration,
-            disable_tracer=self._disable_tracer,
-            path="messages",
-        )
-        validated_messages = validator_service.post_process_validation(
-            value, attempt_number, iteration, OutputTypes.STRING
-        )
-        validated_messages = cast(str, validated_messages)
+        for msg in messages:
+            content = (
+                msg["content"].source
+                if isinstance(msg["content"], Prompt)
+                else msg["content"]
+            )
+            inputs = Inputs(
+                llm_output=content,
+            )
+            iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
+            call_log.iterations.insert(0, iteration)
+            value, _metadata = await validator_service.async_validate(
+                value=content,
+                metadata=self.metadata,
+                validator_map=self.validation_map,
+                iteration=iteration,
+                disable_tracer=self._disable_tracer,
+                path="messages",
+            )

-        iteration.outputs.validation_response = validated_messages
-        if isinstance(validated_messages, ReAsk):
-            raise ValidationError(
-                f"Messages validation failed: " f"{validated_messages}"
+            validated_msg = validator_service.post_process_validation(
+                value, attempt_number, iteration, OutputTypes.STRING
             )
-        if validated_messages != msg_str:
-            raise ValidationError("Messages validation failed")
+
+            iteration.outputs.validation_response = validated_msg
+
+            if isinstance(validated_msg, ReAsk):
+                raise ValidationError(f"Messages validation failed: {validated_msg}")
+            elif not validated_msg or iteration.status == fail_status:
+                raise ValidationError("Messages validation failed")
+
+            msg["content"] = cast(str, validated_msg)
+
+        return messages  # type: ignore

guardrails/run/runner.py

Lines changed: 9 additions & 15 deletions
@@ -287,22 +287,17 @@ def step(
     def validate_messages(
         self, call_log: Call, messages: MessageHistory, attempt_number: int
    ) -> None:
-        msg_str = messages_string(messages)
-        inputs = Inputs(
-            llm_output=msg_str,
-        )
-        iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
-        call_log.iterations.insert(0, iteration)
-
-        validated_msgs = ""
-
         for msg in messages:
             content = (
                 msg["content"].source
                 if isinstance(msg["content"], Prompt)
                 else msg["content"]
             )
-
+            inputs = Inputs(
+                llm_output=content,
+            )
+            iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
+            call_log.iterations.insert(0, iteration)
             value, _metadata = validator_service.validate(
                 value=content,
                 metadata=self.metadata,

@@ -315,16 +310,15 @@
             validated_msg = validator_service.post_process_validation(
                 value, attempt_number, iteration, OutputTypes.STRING
             )
+
+            iteration.outputs.validation_response = validated_msg

             if isinstance(validated_msg, ReAsk):
-                raise ValidationError(f"Message validation failed: {validated_msg}")
+                raise ValidationError(f"Messages validation failed: {validated_msg}")
             elif not validated_msg or iteration.status == fail_status:
-                raise ValidationError("Message validation failed")
+                raise ValidationError("Messages validation failed")

             msg["content"] = cast(str, validated_msg)
-            validated_msgs += validated_msg
-
-            iteration.outputs.validation_response = validated_msgs

         return messages  # type: ignore

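Both the sync and async runners now validate each message on its own instead of validating one flattened string: per message, the content is unwrapped from a `Prompt` to its `.source` if needed, logged as a fresh `Iteration`, run through the validators registered under the `messages` path, and the loop aborts with `ValidationError` on a `ReAsk`, an empty result, or a failed iteration; otherwise the validated text is written back into the message. A compact sketch of that control flow with the logging and validator plumbing stubbed out (`validate_one` is a placeholder, not a guardrails API):

    from typing import Callable, Dict, List, Union

    def validate_messages_sketch(
        messages: List[Dict[str, Union[str, object]]],
        validate_one: Callable[[str], str],
    ) -> List[Dict[str, Union[str, object]]]:
        for msg in messages:
            # Unwrap Prompt-like objects to their raw text, as the runners do.
            content = getattr(msg["content"], "source", msg["content"])
            validated = validate_one(content)  # stand-in for the validator_service calls
            if not validated:
                # The real runners raise guardrails' ValidationError here, and also
                # when the result is a ReAsk or the iteration status is "fail".
                raise ValueError("Messages validation failed")
            msg["content"] = validated  # mutate in place, then return the list
        return messages
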
tests/integration_tests/mock_llm_outputs.py

Lines changed: 5 additions & 0 deletions
@@ -128,6 +128,11 @@ def _invoke_llm(
             }

         try:
+            if msg_history:
+                key = (msg_history[0]["content"], msg_history[1]["content"])
+                print("=========trying key", key)
+                out_text = mock_llm_responses[key]
+                print("========found out text", out_text)
             if prompt and instructions and not msg_history:
                 out_text = mock_llm_responses[(prompt, instructions)]
             elif msg_history and not prompt and not instructions:

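The mock now tries a message-history lookup first: it builds a key from the first two message contents and reads the canned output from `mock_llm_responses` before falling back to the prompt/instructions keys. A small sketch of that lookup with stand-in data (the dictionary contents here are invented, not the real fixtures):

    mock_llm_responses = {
        ("Some system text", "Some user text"): '{"answer": "canned output"}',
    }

    msg_history = [
        {"role": "system", "content": "Some system text"},
        {"role": "user", "content": "Some user text"},
    ]

    # Same key construction the diff adds; a missing pair raises KeyError,
    # and the addition sits inside the helper's existing try block.
    key = (msg_history[0]["content"], msg_history[1]["content"])
    out_text = mock_llm_responses[key]
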
tests/integration_tests/test_formatters.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ class Foo(BaseModel):
         bez: List[str]

     g = Guard.from_pydantic(Foo, output_formatter="jsonformer")
-    response = g(model.generate, tokenizer=tokenizer, prompt="test")
+    response = g(model.generate, tokenizer=tokenizer, messages=[{"content": "test","role": "user"}])
     validated_output = response.validated_output
     assert isinstance(validated_output, dict)
     assert "bar" in validated_output

@@ -45,7 +45,7 @@ class Foo(BaseModel):
         bez: List[str]

     g = Guard.from_pydantic(Foo, output_formatter="jsonformer")
-    response = g(model, prompt="Sample:")
+    response = g(model, messages=[{"content": "Sample:","role": "user"}])
     validated_output = response.validated_output
     assert isinstance(validated_output, dict)
     assert "bar" in validated_output

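The formatter tests move from the removed `prompt=` keyword to the messages-list form of the call. A minimal restatement of the updated calling convention, with the model and tokenizer left as placeholders for whatever the integration-test fixtures provide:

    from typing import List
    from pydantic import BaseModel
    from guardrails import Guard

    class Foo(BaseModel):
        bar: str
        bez: List[str]

    g = Guard.from_pydantic(Foo, output_formatter="jsonformer")
    # response = g(
    #     model.generate,
    #     tokenizer=tokenizer,
    #     messages=[{"role": "user", "content": "test"}],
    # )
    # assert isinstance(response.validated_output, dict)
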
tests/integration_tests/test_parsing.py

Lines changed: 10 additions & 9 deletions
@@ -49,16 +49,16 @@ def test_parsing_reask(mocker):
     assert call.iterations.length == 2

     # For orginal prompt and output
-    assert call.compiled_prompt == pydantic.PARSING_COMPILED_PROMPT
+    assert call.compiled_messages[0]["content"]._source == pydantic.PARSING_COMPILED_PROMPT
     assert call.iterations.first.raw_output == pydantic.PARSING_UNPARSEABLE_LLM_OUTPUT
     assert call.iterations.first.guarded_output is None

     # For re-asked prompt and output
-    assert call.iterations.last.inputs.prompt == gd.Prompt(
+    assert call.iterations.last.inputs.messages[1]["content"] == gd.Prompt(
         pydantic.PARSING_COMPILED_REASK
     )
     # Same as above
-    assert call.reask_prompts.last == pydantic.PARSING_COMPILED_REASK
+    assert call.reask_messages[0][1]["content"]._source == pydantic.PARSING_COMPILED_REASK
     assert call.raw_outputs.last == pydantic.PARSING_EXPECTED_LLM_OUTPUT
     assert call.guarded_output == pydantic.PARSING_EXPECTED_OUTPUT

@@ -83,7 +83,8 @@ async def test_async_parsing_reask(mocker):
     ]

     guard = gd.AsyncGuard.from_pydantic(
-        output_class=pydantic.PersonalDetails, prompt=pydantic.PARSING_INITIAL_PROMPT
+        output_class=pydantic.PersonalDetails,
+        messages=[{"role": "user", "content": pydantic.PARSING_INITIAL_PROMPT}],
     )

     final_output = await guard(

@@ -100,17 +101,17 @@ async def test_async_parsing_reask(mocker):
     assert call.iterations.length == 2

     # For orginal prompt and output
-    assert call.compiled_prompt == pydantic.PARSING_COMPILED_PROMPT
+    assert call.compiled_messages[0]["content"]._source == pydantic.PARSING_COMPILED_PROMPT
     assert call.iterations.first.raw_output == pydantic.PARSING_UNPARSEABLE_LLM_OUTPUT
     assert call.iterations.first.guarded_output is None

     # For re-asked prompt and output

-    assert call.iterations.last.inputs.prompt == gd.Prompt(
+    assert call.iterations.last.inputs.messages[1]["content"] == gd.Prompt(
         pydantic.PARSING_COMPILED_REASK
     )
     # Same as above
-    assert call.reask_prompts.last == pydantic.PARSING_COMPILED_REASK
+    assert call.reask_messages[0][1]["content"]._source == pydantic.PARSING_COMPILED_REASK
     assert call.raw_outputs.last == pydantic.PARSING_EXPECTED_LLM_OUTPUT
     assert call.guarded_output == pydantic.PARSING_EXPECTED_OUTPUT

@@ -123,7 +124,7 @@ def test_reask_prompt_instructions(mocker):
     """

     mocker.patch(
-        "guardrails.llm_providers.OpenAIChatCallable._invoke_llm",
+        "guardrails.llm_providers.LiteLLMCallable._invoke_llm",
         return_value=LLMResponse(
             output=string.MSG_LLM_OUTPUT_CORRECT,
             prompt_token_count=123,

@@ -144,7 +145,7 @@ def always_fail(value: str, metadata: Dict) -> ValidationResult:

     guard.parse(
         llm_output="Tomato Cheese Pizza",
-        llm_api=openai.chat.completions.create,
+        model="gpt-3.5-turbo",
         messages=[
             {"role": "system", "content": "Some content"},
             {"role": "user", "content": "Some prompt"},

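The history assertions shift from prompt-centric accessors to message-based ones: the compiled prompt is read from the first compiled message, the reask input from the second message of the last iteration's inputs, and the compiled reask from `call.reask_messages`. A short sketch of the updated access pattern, assuming `call` comes from `guard.history.first` as elsewhere in these tests:

    call = guard.history.first

    # Original compiled prompt, now the first message's content.
    compiled_source = call.compiled_messages[0]["content"]._source

    # Re-ask round: the user-role message fed back in, and its compiled source.
    reask_input = call.iterations.last.inputs.messages[1]["content"]  # a gd.Prompt
    reask_source = call.reask_messages[0][1]["content"]._source

    # Output-side assertions are unchanged by this commit.
    raw = call.raw_outputs.last
    final = call.guarded_output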