Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions mellea/stdlib/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import re
from collections.abc import Callable
from copy import copy
from typing import Any, overload

from mellea.backends import (
Expand Down Expand Up @@ -97,6 +98,9 @@ def __init__(
self.validation_fn = validation_fn
self.check_only = check_only

# Used for validation. Do not manually populate.
self._output: str | None = None

def validate(
self,
backend: Backend,
Expand All @@ -117,17 +121,18 @@ def validate(
assert isinstance(last_output, ModelOutputThunk), (
" Context has no appropriate last output"
)
self._output = last_output.value # type: ignore

# Create a copy of the requirement that holds the output,
# so that its template gets populated with the output correctly.
req_copy = copy(self)
req_copy._output = last_output.value
llm_as_a_judge_result = backend.generate_from_context(
self,
req_copy,
ctx,
format=format,
model_options=model_options,
generate_logs=generate_logs,
)
# This is crucial, because requirements can get reused;
# this also means requirements are not thread-safe.
self._output = None
return ValidationResult(
result=self.output_to_bool(llm_as_a_judge_result),
reason=llm_as_a_judge_result.value,
Expand Down
12 changes: 10 additions & 2 deletions test/stdlib_basics/test_requirement.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import pytest
from mellea.stdlib.base import ModelOutputThunk
from mellea.stdlib.requirement import simple_validate
from mellea.stdlib.session import SimpleContext
from mellea.stdlib.requirement import Requirement, simple_validate
from mellea.stdlib.session import SimpleContext, start_session

ctx = SimpleContext()
ctx.insert(ModelOutputThunk("test"))

def test_llmaj_validation_req_output_field():
    """Check that validating a requirement leaves its `_output` field unset."""
    session = start_session(ctx=ctx)
    requirement = Requirement("Must output test.")

    # A freshly built requirement starts without any captured output.
    assert requirement._output is None

    # Only the effect on `_output` matters here; the validation result is ignored.
    requirement.validate(session.backend, ctx=ctx)
    assert requirement._output is None, (
        "requirement's output shouldn't be updated during/after validation"
    )

def test_simple_validate_bool():
validation_func = simple_validate(lambda x: False, reason="static reason")
val_result = validation_func(ctx)
Expand Down