Skip to content

Commit 80b83fd

Browse files
committed
composable requirements
1 parent f08d2ec commit 80b83fd

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

mellea/stdlib/requirement.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ def format_for_llm(self) -> TemplateRepresentation | str:
156156
template_order=["*", "Requirement"],
157157
)
158158

159+
def __and__(self, other):
160+
return ConjunctiveRequirement([self,other])
161+
def __or__(self, other):
162+
return DisjunctiveRequirement([self,other])
163+
def __not__(self):
164+
return NegativeRequirement(self)
165+
159166

160167
class LLMaJRequirement(Requirement):
161168
"""A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored."""
@@ -178,6 +185,98 @@ def __init__(self, description: str, alora: Alora | None = None):
178185
self.alora = alora
179186

180187

188+
class ConjunctiveRequirement(Requirement):
189+
def __init__(self, requirements: list[Requirement],):
190+
self.requirements = requirements
191+
192+
@property
193+
def description(self):
194+
return "\n* ".join(
195+
["Satisfy all of these requirements:"] + \
196+
[r.description for r in self.requirements])
197+
198+
def validate(self, *args, **kwargs):
199+
results = [r.validate(*args, **kwargs) for r in self.requirements]
200+
return ValidationResult(
201+
result = all(results),
202+
reason = "\n* ".join(
203+
["These requirements are not satisfied:"]+
204+
[r.reason for r in results if not r]),
205+
score = max([r.score for r in results if not r]))
206+
207+
def __and__(self, other):
208+
match other:
209+
case ConjunctiveRequirement():
210+
ConjunctiveRequirement(self.requirements+other.requirements)
211+
case Requirement():
212+
ConjunctiveRequirement(self.requirements+[other])
213+
214+
def __or__(self, other):
215+
return DisjunctiveRequirement([self,other])
216+
def __not__(self):
217+
return NegativeRequirement(self)
218+
219+
220+
221+
class DisjunctiveRequirement(Requirement):
222+
def __init__(self, requirements: list[Requirement],):
223+
self.requirements = requirements
224+
225+
@property
226+
def description(self):
227+
return "\n* ".join(
228+
["Satisfy at least one of these requirements:"] + \
229+
[r.description for r in self.requirements])
230+
231+
def validate(self, *args, **kwargs):
232+
results = [r.validate(*args, **kwargs) for r in self.requirements]
233+
return ValidationResult(
234+
result = any(results),
235+
reason = "\n* ".join(
236+
["These requirements are satisfied:"]+
237+
[r.reason for r in results if r]),
238+
score = min([r.score for r in results if not r]))
239+
240+
def __and__(self, other):
241+
return ConjunctiveRequirement([self,other])
242+
def __or__(self, other):
243+
match other:
244+
case DisjunctiveRequirement():
245+
DisjunctiveRequirement(self.requirements+other.requirements)
246+
case Requirement():
247+
DisjunctiveRequirement(self.requirements+[other])
248+
def __not__(self):
249+
return NegativeRequirement(self)
250+
251+
252+
class NegativeRequirement(Requirement):
253+
def __init__(self, requirement: Requirement,):
254+
self.requirement = requirement
255+
256+
@property
257+
def description(self):
258+
return f"Do not satisfy this requirement: {self.requirement.description}"
259+
260+
def __getattr__(self, name):
261+
# delegate lookup to self.requirement
262+
return getattr(self.requirement, name)
263+
264+
def validate(self, *args, **kwargs):
265+
result = self.requirement.validate(*args, **kwargs)
266+
return ValidationResult(
267+
result = not result,
268+
reason = result.reason,
269+
# score = ???
270+
)
271+
272+
def __and__(self, other):
273+
return ConjunctiveRequirement([self,other])
274+
def __or__(self, other):
275+
return DisjunctiveRequirement([self,other])
276+
def __not__(self):
277+
return self.requirement
278+
279+
181280
def reqify(r: str | Requirement) -> Requirement:
182281
"""Maps strings to Requirements.
183282

0 commit comments

Comments
 (0)