Skip to content

Commit e4b7cd4

Browse files
preparing for new sampling: adding a repair field to Instruction that can only be set by .copy_and_repair(...). The templates are updated to accommodate the new field.
1 parent 06a8637 commit e4b7cd4

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

mellea/stdlib/instruction.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Instructions."""
22

3+
from __future__ import annotations
4+
35
from copy import deepcopy
46

57
import jinja2
@@ -106,6 +108,7 @@ def __init__(
106108
self._output_prefix = (
107109
blockify(output_prefix) if output_prefix is not None else None
108110
)
111+
self._repair_string: str | None = None
109112

110113
def parts(self):
111114
"""Returns all of the constituent parts of an Instruction."""
@@ -132,6 +135,7 @@ def format_for_llm(self) -> TemplateRepresentation:
132135
"output_prefix": (
133136
self._output_prefix if self._output_prefix is not None else None
134137
),
138+
"repair": self._repair_string,
135139
},
136140
tools=None,
137141
template_order=["*", "Instruction"],
@@ -147,3 +151,9 @@ def apply_user_dict_from_jinja(user_dict: dict[str, str], s: str) -> str:
147151
def requirements(self) -> list[Requirement]:
148152
"""Returns a list of Requirement instances."""
149153
return self._requirements
154+
155+
def copy_and_repair(self, repair_string: str) -> Instruction:
156+
"""Creates a copy of the instruction and adds/overwrites the repair string."""
157+
res = deepcopy(self)
158+
res._repair_string = repair_string
159+
return res

mellea/templates/prompts/default/Instruction.jinja2

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@ Here is some grounding context:
3737
{%- endif -%}
3838
{% endblock grounding_context %}
3939

40+
{%- block repair_block -%}
41+
{% if repair %}
42+
{{ repair -}}
43+
{%- endif -%}
44+
{% endblock repair_block %}
45+
4046
{%- block output_prefix -%}
4147
{% if output_prefix %}
4248
{{ output_prefix -}}
4349
{%- endif -%}
44-
{% endblock output_prefix %}
50+
{% endblock output_prefix %}

mellea/templates/prompts/granite/Instruction.jinja2

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@ Here are some examples of what the response might look like:
3737
{%- endif -%}
3838
{% endblock icl_examples %}
3939

40+
{%- block repair_block -%}
41+
{% if repair %}
42+
{{ repair -}}
43+
{%- endif -%}
44+
{% endblock repair_block %}
45+
4046
{%- block output_prefix -%}
4147
{% if output_prefix %}
4248
{{ output_prefix -}}
4349
{%- endif -%}
44-
{% endblock output_prefix %}
50+
{% endblock output_prefix %}

0 commit comments

Comments
 (0)