Skip to content

Commit 36ad134

Browse files
use validation results to communicate the outcome of requirement validation (#54)
* validate returns a ValidationResult * tests now working for ValidationResult (openai backend mypy fix) * overload simple validate for easier dynamic reasons --------- Co-authored-by: Hendrik Strobelt <[email protected]>
1 parent bde69a9 commit 36ad134

File tree

7 files changed

+214
-94
lines changed

7 files changed

+214
-94
lines changed

mellea/backends/openai.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from transformers.tokenization_utils import PreTrainedTokenizer
1919

2020
import mellea.backends.model_ids as model_ids
21-
from cli.serve.models import ChatCompletionMessage
2221
from mellea.backends import BaseModelSubclass
2322
from mellea.backends.aloras import Alora, AloraBackendMixin
2423
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter

mellea/stdlib/requirement.py

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Requirements are a special type of Component used as input to the "validate" step in Instruct/Validate/Repair design patterns."""
22

3+
import inspect
34
import re
45
from collections.abc import Callable
5-
from typing import Any
6+
from typing import Any, overload
67

78
from mellea.backends import (
89
Backend,
@@ -35,13 +36,48 @@ def default_output_to_bool(x: CBlock | str) -> bool:
3536
return False
3637

3738

39+
class ValidationResult:
40+
"""ValidationResults store the output of a Requirement's validation. They can be used to return additional info from validation functions, which is useful for sampling/repairing."""
41+
42+
def __init__(
43+
self, result: bool, *, reason: str | None = None, score: float | None = None
44+
):
45+
"""The result of a requirement's validation.
46+
47+
A ValidationResult's result field always contains a definitive pass/fail. The other fields can be used to communicate additional information about that result.
48+
49+
Args:
50+
result: a boolean that is true if the requirement passed
51+
reason: a reason for the result
52+
score: if your validator gives you a score back, you can add this as metadata
53+
"""
54+
self._result = result
55+
self._reason = reason
56+
self._score = score
57+
58+
@property
59+
def reason(self) -> str | None:
60+
return self._reason
61+
62+
@property
63+
def score(self) -> float | None:
64+
return self._score
65+
66+
def as_bool(self) -> bool:
67+
""""""
68+
return self._result
69+
70+
def __bool__(self) -> bool:
71+
return self.as_bool()
72+
73+
3874
class Requirement(Component):
3975
"""Requirements are a special type of Component used as input to the Validate step in Instruct/Validate/Repair patterns."""
4076

4177
def __init__(
4278
self,
4379
description: str | None = None,
44-
validation_fn: Callable[[Context], Any] | None = None,
80+
validation_fn: Callable[[Context], ValidationResult] | None = None,
4581
*,
4682
output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool,
4783
check_only: bool = False,
@@ -69,12 +105,11 @@ def validate(
69105
format: type[BaseModelSubclass] | None = None,
70106
model_options: dict | None = None,
71107
generate_logs: list[GenerateLog] | None = None,
72-
) -> tuple[Any, bool]:
108+
) -> ValidationResult:
73109
"""Chooses the appropriate validation strategy and applies that strategy."""
74110
if self.validation_fn is not None:
75111
# Python validation strategy
76-
result = self.validation_fn(ctx)
77-
return result, bool(result)
112+
return self.validation_fn(ctx)
78113
else:
79114
# LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch.
80115
assert self.output_to_bool is not None
@@ -93,7 +128,10 @@ def validate(
93128
# This is crucial, because requirements can get reused;
94129
# this also means requirements are not thread-safe.
95130
self._output = None
96-
return llm_as_a_judge_result, self.output_to_bool(llm_as_a_judge_result)
131+
return ValidationResult(
132+
result=self.output_to_bool(llm_as_a_judge_result),
133+
reason=llm_as_a_judge_result.value,
134+
)
97135

98136
def parts(self):
99137
"""Returns all of the constituent parts of a Requirement."""
@@ -158,7 +196,21 @@ def check(*args, **kwargs) -> Requirement:
158196
return Requirement(*args, **kwargs, check_only=True)
159197

160198

161-
def simple_validate(fn: Callable[[str], bool]) -> Callable[[Context], bool]:
199+
@overload
200+
def simple_validate(
201+
fn: Callable[[str], tuple[bool, str]],
202+
) -> Callable[[Context], ValidationResult]: ...
203+
204+
205+
@overload
206+
def simple_validate(
207+
fn: Callable[[str], bool], *, reason: str | None = None
208+
) -> Callable[[Context], ValidationResult]: ...
209+
210+
211+
def simple_validate(
212+
fn: Callable[[str], Any], *, reason: str | None = None
213+
) -> Callable[[Context], ValidationResult]:
162214
"""Syntactic sugar for writing validation functions that only operate over the last output from the model (interpreted as a string).
163215
164216
This is useful when your validation logic only depends upon the most recent model output. For example:
@@ -170,15 +222,36 @@ def simple_validate(fn: Callable[[str], bool]) -> Callable[[Context], bool]:
170222
Important notes:
171223
- this operates over the more recent _model output_, not the most recent message.
172224
- Model outputs are sometimes parsed into more complex types (eg by a `Formatter.parse` call or an OutputProcessor). This validation logic will interpret the most recent output as a string, regardless of whether it has a more complex parsed representation.
225+
226+
Args:
227+
fn: the simple validation function that takes a string and returns either a bool or (bool, str)
228+
reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing
173229
"""
174230

175-
def validate(ctx: Context) -> bool:
231+
def validate(ctx: Context) -> ValidationResult:
176232
o = ctx.last_output()
177233
if o is None or o.value is None:
178234
FancyLogger.get_logger().warn(
179235
"Last output of context was None. That might be a problem. We return validation as False to be able to continue..."
180236
)
181-
return False
182-
return fn(o.value)
237+
return ValidationResult(
238+
False
239+
) # Don't pass in the static reason since the function didn't run.
240+
241+
result = fn(o.value)
242+
243+
# Only confirm that the result conforms to the fn type requirements here. Functions can
244+
# declare return types and then deviate from them.
245+
246+
# Oneliner that checks the tuple actually contains (bool, str)
247+
if isinstance(result, tuple) and list(map(type, result)) == [bool, str]:
248+
return ValidationResult(result[0], reason=result[1])
249+
250+
elif type(result) is bool:
251+
return ValidationResult(result, reason=reason)
252+
253+
raise ValueError(
254+
f"function {fn.__name__} passed to simple_validate didn't return either bool or [bool, str]; returned {type(result)} instead"
255+
)
183256

184257
return validate

mellea/stdlib/sampling.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mellea.helpers.fancy_logger import FancyLogger
1010
from mellea.stdlib.base import CBlock, GenerateLog, ModelOutputThunk
1111
from mellea.stdlib.instruction import Instruction
12-
from mellea.stdlib.requirement import Requirement
12+
from mellea.stdlib.requirement import Requirement, ValidationResult
1313

1414

1515
class SamplingResult(CBlock):
@@ -21,15 +21,16 @@ def __init__(
2121
success: bool,
2222
*,
2323
sample_generations: list[ModelOutputThunk] | None = None,
24-
sample_validations: list[list[tuple[Requirement, bool]]] | None = None,
24+
sample_validations: list[list[tuple[Requirement, ValidationResult]]]
25+
| None = None,
2526
):
2627
"""Initialize a new instance of sampling results.
2728
2829
Args:
2930
result: The final output or result from applying the sampling strategy.
3031
success: A boolean indicating whether the operation was successful.
3132
sample_generations: A list containing intermediate generations produced during the process.
32-
sample_validations: For each generation a list of a requirement and a boolean value indicating whether the requirement was met.
33+
sample_validations: For each generation a list of tuples of a requirement and a validation result.
3334
"""
3435
super().__init__(value=result.value)
3536
self.result = result
@@ -45,7 +46,9 @@ class SamplingStrategy(abc.ABC):
4546
It allows setting custom validation and generation functions through properties.
4647
"""
4748

48-
validate: Callable[[list[Requirement], Any], list[bool]] | None = None
49+
# the function signature here matches that of m.validate
50+
validate: Callable[[list[Requirement], Any], list[ValidationResult]] | None = None
51+
4952
generate: (
5053
Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
5154
) = None
@@ -75,14 +78,23 @@ def __init__(
7578
*,
7679
loop_budget: int = 1,
7780
repair: Callable[
78-
[Instruction, list[tuple[Requirement, bool]], list[Instruction]],
81+
[
82+
Instruction,
83+
list[tuple[Requirement, ValidationResult]],
84+
list[Instruction],
85+
],
7986
Instruction,
8087
] = lambda i, r, h_i: i,
8188
select_from_failure: Callable[
82-
[Instruction, list[ModelOutputThunk], list[list[tuple[Requirement, bool]]]],
89+
[
90+
Instruction,
91+
list[ModelOutputThunk],
92+
list[list[tuple[Requirement, ValidationResult]]],
93+
],
8394
ModelOutputThunk,
8495
] = lambda _, results, __: results[0],
85-
validate: Callable[[list[Requirement], Any], list[bool]] | None = None,
96+
validate: Callable[[list[Requirement], Any], list[ValidationResult]]
97+
| None = None,
8698
generate: (
8799
Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
88100
) = None,
@@ -139,7 +151,7 @@ def sample(
139151
flog = FancyLogger.get_logger()
140152

141153
failed_results: list[ModelOutputThunk] = []
142-
failed_scores: list[list[tuple[Requirement, bool]]] = []
154+
failed_scores: list[list[tuple[Requirement, ValidationResult]]] = []
143155
failed_instructions: list[Instruction] = []
144156

145157
loop_count = 0
@@ -169,7 +181,7 @@ def sample(
169181
failed_scores.append(constraint_scores)
170182
failed_instructions.append(instruction)
171183

172-
if all(s[1] for s in constraint_scores):
184+
if all(bool(s[1]) for s in constraint_scores):
173185
flog.info("SUCCESS")
174186
return SamplingResult(
175187
result,
@@ -179,7 +191,7 @@ def sample(
179191
)
180192

181193
else:
182-
count_valid = len([s for s in constraint_scores if s[1]])
194+
count_valid = len([s for s in constraint_scores if bool(s[1])])
183195
flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
184196
# If we did not pass all constraints, update the instruction and try again.
185197
instruction = self.repair(

mellea/stdlib/session.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from mellea.stdlib.instruction import Instruction
3232
from mellea.stdlib.mify import mify
3333
from mellea.stdlib.mobject import MObjectProtocol
34-
from mellea.stdlib.requirement import Requirement, check, req
34+
from mellea.stdlib.requirement import Requirement, ValidationResult, check, req
3535
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy
3636

3737

@@ -293,11 +293,10 @@ def validate(
293293
reqs: Requirement | list[Requirement],
294294
*,
295295
output: CBlock | None = None,
296-
return_full_validation_results: bool = False,
297296
format: type[BaseModelSubclass] | None = None,
298297
model_options: dict | None = None,
299298
generate_logs: list[GenerateLog] | None = None,
300-
) -> list[bool] | list[tuple[Any, bool]]:
299+
) -> list[ValidationResult]:
301300
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
302301
# Turn a solitary requirement in to a list of requirements, and then reqify if needed.
303302
reqs = [reqs] if not isinstance(reqs, list) else reqs
@@ -309,18 +308,16 @@ def validate(
309308
validation_target_ctx.insert(output)
310309
rvs = []
311310
for requirement in reqs:
312-
req_v, req_satisfied = requirement.validate(
311+
val_result = requirement.validate(
313312
self.backend,
314313
validation_target_ctx,
315314
format=format,
316315
model_options=model_options,
317316
generate_logs=generate_logs,
318317
)
319-
rvs.append((req_v, req_satisfied))
320-
if return_full_validation_results:
321-
return rvs
322-
else:
323-
return [b for (_, b) in rvs]
318+
rvs.append(val_result)
319+
320+
return rvs
324321

325322
def req(self, *args, **kwargs):
326323
"""Shorthand for Requirement.__init__(...)."""

0 commit comments

Comments
 (0)