From 7ec22cd973acb850b7c84ade4f84424622b0c8a2 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Wed, 20 Aug 2025 11:47:23 -0400 Subject: [PATCH] copy reqs during validation; add test --- mellea/stdlib/requirement.py | 15 ++++++++++----- test/stdlib_basics/test_requirement.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index ad4709fa..1aa849fa 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -3,6 +3,7 @@ import inspect import re from collections.abc import Callable +from copy import copy from typing import Any, overload from mellea.backends import ( @@ -97,6 +98,9 @@ def __init__( self.validation_fn = validation_fn self.check_only = check_only + # Used for validation. Do not manually populate. + self._output: str | None = None + def validate( self, backend: Backend, @@ -117,17 +121,18 @@ def validate( assert isinstance(last_output, ModelOutputThunk), ( " Context has no appropriate last output" ) - self._output = last_output.value # type: ignore + + # Create a copy of the requirement that holds the output + # and its template gets populated with the output correctly. + req_copy = copy(self) + req_copy._output = last_output.value llm_as_a_judge_result = backend.generate_from_context( - self, + req_copy, ctx, format=format, model_options=model_options, generate_logs=generate_logs, ) - # This is crucial, because requirements can get reused; - # this also means requirements are not thread-safe. - self._output = None return ValidationResult( result=self.output_to_bool(llm_as_a_judge_result), reason=llm_as_a_judge_result.value, diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib_basics/test_requirement.py index aa7845d4..12af105a 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib_basics/test_requirement.py @@ -1,11 +1,19 @@ import pytest from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.requirement import simple_validate -from mellea.stdlib.session import SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.stdlib.session import SimpleContext, start_session ctx = SimpleContext() ctx.insert(ModelOutputThunk("test")) +def test_llmaj_validation_req_output_field(): + m = start_session(ctx=ctx) + req = Requirement("Must output test.") + assert req._output is None + + _ = req.validate(m.backend,ctx=ctx) + assert req._output is None, "requirement's output shouldn't be updated during/after validation" + def test_simple_validate_bool(): validation_func = simple_validate(lambda x: False, reason="static reason") val_result = validation_func(ctx)