Skip to content

Commit 4819407

Browse files
authored
Make requirements thread-safe and add a test that checks for state changes (#83)
1 parent dda2488 commit 4819407

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

mellea/stdlib/requirement.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import re
55
from collections.abc import Callable
6+
from copy import copy
67
from typing import Any, overload
78

89
from mellea.backends import (
@@ -97,6 +98,9 @@ def __init__(
9798
self.validation_fn = validation_fn
9899
self.check_only = check_only
99100

101+
# Used for validation. Do not manually populate.
102+
self._output: str | None = None
103+
100104
def validate(
101105
self,
102106
backend: Backend,
@@ -117,17 +121,18 @@ def validate(
117121
assert isinstance(last_output, ModelOutputThunk), (
118122
" Context has no appropriate last output"
119123
)
120-
self._output = last_output.value # type: ignore
124+
125+
# Create a copy of the requirement that holds the output
126+
# so that its template gets populated with the output correctly.
127+
req_copy = copy(self)
128+
req_copy._output = last_output.value
121129
llm_as_a_judge_result = backend.generate_from_context(
122-
self,
130+
req_copy,
123131
ctx,
124132
format=format,
125133
model_options=model_options,
126134
generate_logs=generate_logs,
127135
)
128-
# This is crucial, because requirements can get reused;
129-
# this also means requirements are not thread-safe.
130-
self._output = None
131136
return ValidationResult(
132137
result=self.output_to_bool(llm_as_a_judge_result),
133138
reason=llm_as_a_judge_result.value,

test/stdlib_basics/test_requirement.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import pytest
22
from mellea.stdlib.base import ModelOutputThunk
3-
from mellea.stdlib.requirement import simple_validate
4-
from mellea.stdlib.session import SimpleContext
3+
from mellea.stdlib.requirement import Requirement, simple_validate
4+
from mellea.stdlib.session import SimpleContext, start_session
55

66
ctx = SimpleContext()
77
ctx.insert(ModelOutputThunk("test"))
88

9+
def test_llmaj_validation_req_output_field():
10+
m = start_session(ctx=ctx)
11+
req = Requirement("Must output test.")
12+
assert req._output is None
13+
14+
_ = req.validate(m.backend,ctx=ctx)
15+
assert req._output is None, "requirement's output shouldn't be updated during/after validation"
16+
917
def test_simple_validate_bool():
1018
validation_func = simple_validate(lambda x: False, reason="static reason")
1119
val_result = validation_func(ctx)

0 commit comments

Comments
 (0)