Skip to content

Commit 024d9a6

Browse files
committed
composable requirements
1 parent f92a286 commit 024d9a6

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

mellea/stdlib/requirement.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ def format_for_llm(self) -> TemplateRepresentation | str:
151151
template_order=["*", "Requirement"],
152152
)
153153

154+
def __and__(self, other):
155+
return ConjunctiveRequirement([self,other])
156+
def __or__(self, other):
157+
return DisjunctiveRequirement([self,other])
158+
def __not__(self):
159+
return NegativeRequirement(self)
160+
154161

155162
class LLMaJRequirement(Requirement):
156163
"""A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored."""
@@ -173,6 +180,98 @@ def __init__(self, description: str, alora: Alora | None = None):
173180
self.alora = alora
174181

175182

183+
class ConjunctiveRequirement(Requirement):
184+
def __init__(self, requirements: list[Requirement],):
185+
self.requirements = requirements
186+
187+
@property
188+
def description(self):
189+
return "\n* ".join(
190+
["Satisfy all of these requirements:"] + \
191+
[r.description for r in self.requirements])
192+
193+
def validate(self, *args, **kwargs):
194+
results = [r.validate(*args, **kwargs) for r in self.requirements]
195+
return ValidationResult(
196+
result = all(results),
197+
reason = "\n* ".join(
198+
["These requirements are not satisfied:"]+
199+
[r.reason for r in results if not r]),
200+
score = max([r.score for r in results if not r]))
201+
202+
def __and__(self, other):
203+
match other:
204+
case ConjunctiveRequirement():
205+
ConjunctiveRequirement(self.requirements+other.requirements)
206+
case Requirement():
207+
ConjunctiveRequirement(self.requirements+[other])
208+
209+
def __or__(self, other):
210+
return DisjunctiveRequirement([self,other])
211+
def __not__(self):
212+
return NegativeRequirement(self)
213+
214+
215+
216+
class DisjunctiveRequirement(Requirement):
217+
def __init__(self, requirements: list[Requirement],):
218+
self.requirements = requirements
219+
220+
@property
221+
def description(self):
222+
return "\n* ".join(
223+
["Satisfy at least one of these requirements:"] + \
224+
[r.description for r in self.requirements])
225+
226+
def validate(self, *args, **kwargs):
227+
results = [r.validate(*args, **kwargs) for r in self.requirements]
228+
return ValidationResult(
229+
result = any(results),
230+
reason = "\n* ".join(
231+
["These requirements are satisfied:"]+
232+
[r.reason for r in results if r]),
233+
score = min([r.score for r in results if not r]))
234+
235+
def __and__(self, other):
236+
return ConjunctiveRequirement([self,other])
237+
def __or__(self, other):
238+
match other:
239+
case DisjunctiveRequirement():
240+
DisjunctiveRequirement(self.requirements+other.requirements)
241+
case Requirement():
242+
DisjunctiveRequirement(self.requirements+[other])
243+
def __not__(self):
244+
return NegativeRequirement(self)
245+
246+
247+
class NegativeRequirement(Requirement):
248+
def __init__(self, requirement: Requirement,):
249+
self.requirement = requirement
250+
251+
@property
252+
def description(self):
253+
return f"Do not satisfy this requirement: {self.requirement.description}"
254+
255+
def __getattr__(self, name):
256+
# delegate lookup to self.requirement
257+
return getattr(self.requirement, name)
258+
259+
def validate(self, *args, **kwargs):
260+
result = self.requirement.validate(*args, **kwargs)
261+
return ValidationResult(
262+
result = not result,
263+
reason = result.reason,
264+
# score = ???
265+
)
266+
267+
def __and__(self, other):
268+
return ConjunctiveRequirement([self,other])
269+
def __or__(self, other):
270+
return DisjunctiveRequirement([self,other])
271+
def __not__(self):
272+
return self.requirement
273+
274+
176275
def reqify(r: str | Requirement) -> Requirement:
177276
"""Maps strings to Requirements.
178277

0 commit comments

Comments
 (0)