Skip to content

Commit 5acc286

Browse files
new sampling strategies (#65)
* Adds a repair field to Instruction that can only be set by `.copy_and_repair(...)`. The templates are updated to accommodate the new field. * new signature for RejectionSampling * adding RejectionSampling and AgenticSampling (aka multi-turn sampling) as subclasses of BaseSampling * fixing requirements and adding tests * refactoring repair and select-from-failure as abtsract methods. --------- Co-authored-by: jakelorocco <[email protected]>
1 parent d92a44f commit 5acc286

File tree

6 files changed

+321
-69
lines changed

6 files changed

+321
-69
lines changed

mellea/stdlib/instruction.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Instructions."""
22

3+
from __future__ import annotations
4+
35
from copy import deepcopy
46

57
import jinja2
@@ -106,6 +108,7 @@ def __init__(
106108
self._output_prefix = (
107109
blockify(output_prefix) if output_prefix is not None else None
108110
)
111+
self._repair_string: str | None = None
109112

110113
def parts(self):
111114
"""Returns all of the constituent parts of an Instruction."""
@@ -132,6 +135,7 @@ def format_for_llm(self) -> TemplateRepresentation:
132135
"output_prefix": (
133136
self._output_prefix if self._output_prefix is not None else None
134137
),
138+
"repair": self._repair_string,
135139
},
136140
tools=None,
137141
template_order=["*", "Instruction"],
@@ -147,3 +151,9 @@ def apply_user_dict_from_jinja(user_dict: dict[str, str], s: str) -> str:
147151
def requirements(self) -> list[Requirement]:
148152
"""Returns a list of Requirement instances."""
149153
return self._requirements
154+
155+
def copy_and_repair(self, repair_string: str) -> Instruction:
156+
"""Creates a copy of the instruction and adds/overwrites the repair string."""
157+
res = deepcopy(self)
158+
res._repair_string = repair_string
159+
return res

0 commit comments

Comments
 (0)