Skip to content

Commit e8eb7f5

Browse files
committed
typing and lint
1 parent bc7ff07 commit e8eb7f5

File tree

8 files changed

+57
-38
lines changed

8 files changed

+57
-38
lines changed

guardrails/applications/text2sql.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@ def __init__(
7070
rail_spec: Optional[str] = None,
7171
rail_params: Optional[Dict] = None,
7272
example_formatter: Callable = example_formatter,
73-
reask_prompt: str = REASK_PROMPT,
73+
reask_messages: list[Dict[str, str]] = [
74+
{
75+
"role": "user",
76+
"content": REASK_PROMPT,
77+
}
78+
],
7479
llm_api: Optional[Callable] = None,
7580
llm_api_kwargs: Optional[Dict] = None,
7681
num_relevant_examples: int = 2,
@@ -108,7 +113,7 @@ def __init__(
108113
schema_file,
109114
rail_spec,
110115
rail_params,
111-
reask_prompt,
116+
reask_messages,
112117
)
113118

114119
# Initialize the document store.
@@ -122,7 +127,12 @@ def _init_guard(
122127
schema_file: Optional[str] = None,
123128
rail_spec: Optional[str] = None,
124129
rail_params: Optional[Dict] = None,
125-
reask_prompt: str = REASK_PROMPT,
130+
reask_messages: list[Dict[str, str]] = [
131+
{
132+
"role": "user",
133+
"content": REASK_PROMPT,
134+
}
135+
],
126136
):
127137
# Initialize the Guard class
128138
if rail_spec is None:
@@ -140,7 +150,7 @@ def _init_guard(
140150
rail_spec_str = Template(rail_spec_str).safe_substitute(**rail_params)
141151

142152
guard = Guard.from_rail_string(rail_spec_str)
143-
guard._exec_opts.reask_prompt = reask_prompt
153+
guard._exec_opts.reask_messages = reask_messages
144154

145155
return guard
146156

guardrails/classes/history/call.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,23 @@ def prompt_params(self) -> Optional[Dict]:
7878
return self.inputs.prompt_params
7979

8080
@property
81-
def messages(self) -> Optional[Messages]:
81+
def messages(self) -> Optional[Union[Messages, list[dict[str, str]]]]:
8282
"""The messages as provided by the user when initializing or calling the
8383
Guard."""
8484
return self.inputs.messages
8585

8686
@property
87-
def compiled_messages(self) -> Optional[str]:
87+
def compiled_messages(self) -> Optional[list[dict[str, str]]]:
8888
"""The initial compiled messages that were passed to the LLM on the
8989
first call."""
9090
if self.iterations.empty():
9191
return None
92-
initial_inputs = self.iterations.first.inputs
93-
messages: Messages = initial_inputs.messages
92+
initial_inputs = self.iterations.first.inputs # type: ignore
93+
messages = initial_inputs.messages
9494
prompt_params = initial_inputs.prompt_params or {}
9595
compiled_messages = []
96+
if messages is None:
97+
return None
9698
for message in messages:
9799
content = message["content"].format(**prompt_params)
98100
if isinstance(content, (Prompt, Instructions)):
@@ -116,11 +118,11 @@ def reask_messages(self) -> Stack[Messages]:
116118
reasks = self.iterations.copy()
117119
initial_messages = reasks.first
118120
reasks.remove(initial_messages) # type: ignore
119-
initial_inputs = self.iterations.first.inputs
121+
initial_inputs = self.iterations.first.inputs # type: ignore
120122
prompt_params = initial_inputs.prompt_params or {}
121123
compiled_reasks = []
122124
for reask in reasks:
123-
messages: Messages = reask.inputs.messages
125+
messages = reask.inputs.messages
124126

125127
if messages is None:
126128
compiled_reasks.append(None)
@@ -136,7 +138,7 @@ def reask_messages(self) -> Stack[Messages]:
136138
"content": content,
137139
}
138140
)
139-
compiled_reasks.append(compiled_messages)
141+
compiled_reasks.append(compiled_messages)
140142
return Stack(*compiled_reasks)
141143

142144
return Stack()

guardrails/classes/history/iteration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from guardrails.classes.history.outputs import Outputs
1313
from guardrails.classes.generic.arbitrary_model import ArbitraryModel
1414
from guardrails.logger import get_scope_handler
15-
from guardrails.prompt.prompt import Prompt
15+
from guardrails.prompt import Prompt, Instructions
1616
from guardrails.classes.validation.validator_logs import ValidatorLogs
1717
from guardrails.actions.reask import ReAsk
1818
from guardrails.classes.validation.validation_result import ErrorSpan
@@ -189,7 +189,7 @@ def status(self) -> str:
189189
@property
190190
def rich_group(self) -> Group:
191191
def create_messages_table(
192-
messages: Optional[List[Dict[str, Prompt]]],
192+
messages: Optional[List[Dict[str, Union[str, Prompt, Instructions]]]],
193193
) -> Union[str, Table]:
194194
if messages is None:
195195
return "No messages."
@@ -198,11 +198,11 @@ def create_messages_table(
198198
table.add_column("Content")
199199

200200
for msg in messages:
201-
table.add_row(str(msg["role"]), msg["content"])
201+
table.add_row(str(msg["role"]), msg["content"]) # type: ignore
202202

203203
return table
204204

205-
table = create_messages_table(self.inputs.messages)
205+
table = create_messages_table(self.inputs.messages) # type: ignore
206206

207207
return Group(
208208
Panel(table, title="Messages", style="on #E7DFEB"),

guardrails/formatters/json_formatter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@ def fn(
106106
**kwargs,
107107
) -> str:
108108
prompt = ""
109-
for msg in messages:
109+
for msg in messages: # type: ignore
110110
prompt += msg["content"]
111111

112112
return json.dumps(
113113
Jsonformer(
114114
model=model.model,
115115
tokenizer=model.tokenizer,
116116
json_schema=self.output_schema,
117-
prompt=prompt
117+
prompt=prompt,
118118
)()
119119
)
120120

@@ -132,15 +132,15 @@ def fn(
132132
**kwargs,
133133
) -> str:
134134
prompt = ""
135-
for msg in messages:
135+
for msg in messages: # type: ignore
136136
prompt += msg["content"]
137-
137+
138138
return json.dumps(
139139
Jsonformer(
140140
model=model,
141141
tokenizer=tokenizer,
142142
json_schema=self.output_schema,
143-
prompt=prompt
143+
prompt=prompt,
144144
)()
145145
)
146146

guardrails/llm_providers.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@
3232

3333

3434
# todo fix circular import
35-
def messages_string(messages: MessageHistory) -> str:
35+
def messages_string(
36+
messages: Union[list[dict[str, Union[str, Prompt, Instructions]]], MessageHistory],
37+
) -> str:
3638
messages_copy = ""
3739
for msg in messages:
3840
content = (
39-
msg["content"].source
41+
msg["content"].source # type: ignore
4042
if isinstance(msg["content"], Prompt)
41-
else msg["content"]
43+
or isinstance(msg["content"], Instructions) # type: ignore
44+
else msg["content"] # type: ignore
4245
)
4346
messages_copy += content
4447
return messages_copy
@@ -269,7 +272,9 @@ def _invoke_llm(
269272
self,
270273
model_generate: Any,
271274
*args,
272-
messages: list[dict[str, Union[str, Prompt, Instructions]]],
275+
messages: Union[
276+
list[dict[str, Union[str, Prompt, Instructions]]], MessageHistory
277+
],
273278
**kwargs,
274279
) -> LLMResponse:
275280
try:

guardrails/prompt/messages.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import re
44
from string import Template
5-
from typing import Dict, List, Optional
6-
5+
from typing import Dict, List, Optional, Union
76

7+
from guardrails.prompt import Prompt, Instructions
88
from guardrails.classes.templating.namespace_template import NamespaceTemplate
99
from guardrails.utils.constants import constants
1010
from guardrails.utils.templating_utils import get_template_variables
@@ -13,7 +13,7 @@
1313
class Messages:
1414
def __init__(
1515
self,
16-
source: List[Dict[str, str]],
16+
source: List[Dict[str, Union[str, Prompt, Instructions]]],
1717
output_schema: Optional[str] = None,
1818
*,
1919
xml_output_schema: Optional[str] = None,
@@ -71,9 +71,7 @@ def format(
7171
filtered_kwargs = {k: v for k, v in kwargs.items() if k in vars}
7272

7373
# Return another instance of the class with the formatted message.
74-
formatted_message = Template(msg_str).safe_substitute(
75-
**filtered_kwargs
76-
)
74+
formatted_message = Template(msg_str).safe_substitute(**filtered_kwargs)
7775
formatted_messages.append(
7876
{"role": message["role"], "content": formatted_message}
7977
)

guardrails/run/runner.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
)
1717
from guardrails.logger import set_scope
1818
from guardrails.prompt import Prompt
19-
from guardrails.run.utils import messages_source, messages_string
19+
from guardrails.prompt.messages import Messages
20+
from guardrails.run.utils import messages_source
2021
from guardrails.schema.rail_schema import json_schema_to_rail_output
2122
from guardrails.schema.validator import schema_validation
2223
from guardrails.hub_telemetry.hub_tracing import trace
@@ -35,6 +36,7 @@
3536
from guardrails.actions.reask import NonParseableReAsk, ReAsk, introspect
3637
from guardrails.telemetry import trace_call, trace_step
3738

39+
3840
class Runner:
3941
"""Runner class that calls an LLM API with a prompt, and performs input and
4042
output validation.
@@ -293,9 +295,11 @@ def validate_messages(
293295
else msg["content"]
294296
)
295297
inputs = Inputs(
296-
llm_output=content,
297-
)
298-
iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs)
298+
llm_output=content,
299+
)
300+
iteration = Iteration(
301+
call_id=call_log.id, index=attempt_number, inputs=inputs
302+
)
299303
call_log.iterations.insert(0, iteration)
300304
value, _metadata = validator_service.validate(
301305
value=content,
@@ -309,7 +313,7 @@ def validate_messages(
309313
validated_msg = validator_service.post_process_validation(
310314
value, attempt_number, iteration, OutputTypes.STRING
311315
)
312-
316+
313317
iteration.outputs.validation_response = validated_msg
314318

315319
if isinstance(validated_msg, ReAsk):
@@ -501,7 +505,7 @@ def prepare_to_loop(
501505
prompt_params: Optional[Dict] = None,
502506
) -> Tuple[
503507
Dict[str, Any],
504-
Optional[List[Dict]],
508+
Optional[Union[List[Dict], Messages]],
505509
]:
506510
"""Prepare to loop again."""
507511
prompt_params = prompt_params or {}

guardrails/utils/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def generate_test_artifacts(
4949
ext = f"_reask_{i}"
5050

5151
# Save the compiled prompt.
52-
compiled_prompt = logs.inputs.prompt
52+
compiled_messages = logs.inputs.messages
5353
with open(
5454
os.path.join(artifact_dir, f"compiled_prompt_{on_fail_type}{ext}.txt"), "w"
5555
) as f:
56-
f.write(str(compiled_prompt or ""))
56+
f.write(str(compiled_messages or ""))
5757

5858
# Save the llm output.
5959
llm_output = logs.raw_output

0 commit comments

Comments (0)