Skip to content

Commit b8fc8e1

Browse files
feat: decompose cli tool enhancements & new prompt_modules (#170)
* fixes minor typo
  Signed-off-by: Tulio Coppola <[email protected]>
* places system prompts on their correct place
  Signed-off-by: Tulio Coppola <[email protected]>
* new "general_instructions" prompt_module & m_decomp template versioning
  Signed-off-by: Tulio Coppola <[email protected]>
* new validation_decision module & adjustments
  Signed-off-by: Tulio Coppola <[email protected]>
* minor adjustments
  Signed-off-by: Tulio Coppola <[email protected]>
* adjustments for intermediate release
  Signed-off-by: Tulio Coppola <[email protected]>
* fix: avoid instruction template on decompose tool
  Signed-off-by: Tulio Coppola <[email protected]>
---------
Signed-off-by: Tulio Coppola <[email protected]>
Co-authored-by: jakelorocco <[email protected]>
1 parent 689e1a9 commit b8fc8e1

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

62 files changed: +908 / -106 lines changed

cli/decompose/decompose.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,22 @@
11
import json
22
import keyword
3+
from enum import Enum
34
from pathlib import Path
45
from typing import Annotated
56

67
import typer
78

89
from .pipeline import DecompBackend
910

11+
12+
# Must maintain declaration order
13+
# Newer versions must be declared at the bottom
14+
class DecompVersion(str, Enum):
15+
latest = "latest"
16+
v1 = "v1"
17+
# v2 = "v2"
18+
19+
1020
this_file_dir = Path(__file__).resolve().parent
1121

1222

@@ -76,6 +86,13 @@ def run(
7686
)
7787
),
7888
] = None,
89+
version: Annotated[
90+
DecompVersion,
91+
typer.Option(
92+
help=("Version of the mellea program generator template to be used."),
93+
case_sensitive=False,
94+
),
95+
] = DecompVersion.latest,
7996
input_var: Annotated[
8097
list[str] | None,
8198
typer.Option(
@@ -99,7 +116,13 @@ def run(
99116
environment = Environment(
100117
loader=FileSystemLoader(this_file_dir), autoescape=False
101118
)
102-
m_template = environment.get_template("m_decomp_result.py.jinja2")
119+
120+
ver = (
121+
list(DecompVersion)[-1].value
122+
if version == DecompVersion.latest
123+
else version.value
124+
)
125+
m_template = environment.get_template(f"m_decomp_result_{ver}.py.jinja2")
103126

104127
out_name = out_name.strip()
105128
assert validate_filename(out_name), (

cli/decompose/m_decomp_result.py.jinja2 renamed to cli/decompose/m_decomp_result_v1.py.jinja2

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -18,19 +18,19 @@ except KeyError as e:
1818
print(f"ERROR: One or more required environment variables are not set; {e}")
1919
exit(1)
2020
{%- endif %}
21-
{% for item in subtasks%}
21+
{% for item in subtasks %}
2222
{% set i = loop.index0 %}
2323
# {{ item.subtask }} - {{ item.tag }}
24-
subtask_{{ loop.index }} = m.instruct(
24+
{{ item.tag | lower }} = m.instruct(
2525
textwrap.dedent(
2626
R"""
27-
{{ item.prompt_template | trim | indent(width=8, first=False) }}
27+
{{ item.prompt_template | trim | indent(width=8, first=False) }}
2828
""".strip()
2929
),
3030
{%- if item.constraints %}
3131
requirements=[
32-
{%- for con in item.constraints %}
33-
{{ con | tojson}},
32+
{%- for c in item.constraints %}
33+
{{ c.constraint | tojson}},
3434
{%- endfor %}
3535
],
3636
{%- else %}
@@ -39,22 +39,22 @@ subtask_{{ loop.index }} = m.instruct(
3939
{%- if loop.first and not user_inputs %}
4040
{%- else %}
4141
user_variables={
42-
{%- if user_inputs %}
43-
{%- for var in user_inputs %}
42+
{%- for var in item.input_vars_required %}
4443
{{ var | upper | tojson }}: {{ var | lower }},
4544
{%- endfor %}
46-
{%- endif %}
4745

48-
{%- for j in range(i) %}
49-
{{ subtasks[j].tag | tojson }}: subtask_{{ i }}.value if subtask_{{ i }}.value is not None else "",
46+
{%- for var in item.depends_on %}
47+
{{ var | upper | tojson }}: {{ var | lower }}.value,
5048
{%- endfor %}
5149
},
5250
{%- endif %}
5351
)
52+
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
5453
{%- if loop.last %}
5554

56-
final_response = subtask_{{ loop.index }}.value
5755

58-
print(final_response)
56+
final_answer = {{ item.tag | lower }}.value
57+
58+
print(final_answer)
5959
{%- endif -%}
6060
{%- endfor -%}

cli/decompose/pipeline.py

Lines changed: 65 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1+
import re
12
from enum import Enum
2-
from typing import TypedDict
3+
from typing import Literal, TypedDict
34

45
from typing_extensions import NotRequired
56

@@ -10,27 +11,37 @@
1011

1112
from .prompt_modules import (
1213
constraint_extractor,
14+
# general_instructions,
1315
subtask_constraint_assign,
1416
subtask_list,
1517
subtask_prompt_generator,
18+
validation_decision,
1619
)
1720
from .prompt_modules.subtask_constraint_assign import SubtaskPromptConstraintsItem
1821
from .prompt_modules.subtask_list import SubtaskItem
1922
from .prompt_modules.subtask_prompt_generator import SubtaskPromptItem
2023

2124

25+
class ConstraintResult(TypedDict):
26+
constraint: str
27+
validation_strategy: str
28+
29+
2230
class DecompSubtasksResult(TypedDict):
2331
subtask: str
2432
tag: str
25-
constraints: list[str]
33+
constraints: list[ConstraintResult]
2634
prompt_template: str
35+
# general_instructions: str
36+
input_vars_required: list[str]
37+
depends_on: list[str]
2738
generated_response: NotRequired[str]
2839

2940

3041
class DecompPipelineResult(TypedDict):
3142
original_task_prompt: str
3243
subtask_list: list[str]
33-
identified_constraints: list[str]
44+
identified_constraints: list[ConstraintResult]
3445
subtasks: list[DecompSubtasksResult]
3546
final_response: NotRequired[str]
3647

@@ -41,6 +52,9 @@ class DecompBackend(str, Enum):
4152
rits = "rits"
4253

4354

55+
RE_JINJA_VAR = re.compile(r"\{\{\s*(.*?)\s*\}\}")
56+
57+
4458
def decompose(
4559
task_prompt: str,
4660
user_input_variable: list[str] | None = None,
@@ -53,15 +67,12 @@ def decompose(
5367
if user_input_variable is None:
5468
user_input_variable = []
5569

70+
# region Backend Assignment
5671
match backend:
5772
case DecompBackend.ollama:
5873
m_session = MelleaSession(
5974
OllamaModelBackend(
60-
model_id=model_id,
61-
model_options={
62-
ModelOption.CONTEXT_WINDOW: 32768,
63-
"timeout": backend_req_timeout,
64-
},
75+
model_id=model_id, model_options={ModelOption.CONTEXT_WINDOW: 16384}
6576
)
6677
)
6778
case DecompBackend.openai:
@@ -96,13 +107,19 @@ def decompose(
96107
model_options={"timeout": backend_req_timeout},
97108
)
98109
)
110+
# endregion
99111

100112
subtasks: list[SubtaskItem] = subtask_list.generate(m_session, task_prompt).parse()
101113

102114
task_prompt_constraints: list[str] = constraint_extractor.generate(
103-
m_session, task_prompt
115+
m_session, task_prompt, enforce_same_words=False
104116
).parse()
105117

118+
constraint_validation_strategies: dict[str, Literal["code", "llm"]] = {
119+
cons_key: validation_decision.generate(m_session, cons_key).parse()
120+
for cons_key in task_prompt_constraints
121+
}
122+
106123
subtask_prompts: list[SubtaskPromptItem] = subtask_prompt_generator.generate(
107124
m_session,
108125
task_prompt,
@@ -122,15 +139,52 @@ def decompose(
122139
DecompSubtasksResult(
123140
subtask=subtask_data.subtask,
124141
tag=subtask_data.tag,
125-
constraints=subtask_data.constraints,
142+
constraints=[
143+
{
144+
"constraint": cons_str,
145+
"validation_strategy": constraint_validation_strategies[cons_str],
146+
}
147+
for cons_str in subtask_data.constraints
148+
],
126149
prompt_template=subtask_data.prompt_template,
150+
# general_instructions=general_instructions.generate(
151+
# m_session, input_str=subtask_data.prompt_template
152+
# ).parse(),
153+
input_vars_required=list(
154+
dict.fromkeys( # Remove duplicates while preserving the original order.
155+
[
156+
item
157+
for item in re.findall(
158+
RE_JINJA_VAR, subtask_data.prompt_template
159+
)
160+
if item in user_input_variable
161+
]
162+
)
163+
),
164+
depends_on=list(
165+
dict.fromkeys( # Remove duplicates while preserving the original order.
166+
[
167+
item
168+
for item in re.findall(
169+
RE_JINJA_VAR, subtask_data.prompt_template
170+
)
171+
if item not in user_input_variable
172+
]
173+
)
174+
),
127175
)
128176
for subtask_data in subtask_prompts_with_constraints
129177
]
130178

131179
return DecompPipelineResult(
132180
original_task_prompt=task_prompt,
133181
subtask_list=[item.subtask for item in subtasks],
134-
identified_constraints=task_prompt_constraints,
182+
identified_constraints=[
183+
{
184+
"constraint": cons_str,
185+
"validation_strategy": constraint_validation_strategies[cons_str],
186+
}
187+
for cons_str in task_prompt_constraints
188+
],
135189
subtasks=decomp_subtask_result,
136190
)
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
from .constraint_extractor import constraint_extractor as constraint_extractor
2+
from .general_instructions import general_instructions as general_instructions
23
from .subtask_constraint_assign import (
34
subtask_constraint_assign as subtask_constraint_assign,
45
)
56
from .subtask_list import subtask_list as subtask_list
67
from .subtask_prompt_generator import (
78
subtask_prompt_generator as subtask_prompt_generator,
89
)
10+
from .validation_decision import validation_decision as validation_decision

cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,7 @@
44

55
from mellea import MelleaSession
66
from mellea.backends.types import ModelOption
7-
from mellea.stdlib.base import CBlock
8-
from mellea.stdlib.instruction import Instruction
7+
from mellea.stdlib.chat import Message
98

109
from .._prompt_modules import PromptModule, PromptModuleString
1110
from ._exceptions import BackendGenerationError, TagExtractionError
@@ -14,7 +13,7 @@
1413
T = TypeVar("T")
1514

1615
RE_VERIFIED_CONS_COND = re.compile(
17-
r"<constraints_and_conditions>(.+?)</constraints_and_conditions>",
16+
r"<constraints_and_requirements>(.+?)</constraints_and_requirements>",
1817
flags=re.IGNORECASE | re.DOTALL,
1918
)
2019

@@ -33,13 +32,13 @@ def _default_parser(generated_str: str) -> list[str]:
3332
generated_str (`str`): The LLM's answer to be parsed.
3433
3534
Returns:
36-
list[str]: A list of identified constraints in natural language. The list
35+
list[str]: A list of identified constraints and requirements in natural language. The list
3736
will be empty if no constraints were identified by the LLM.
3837
3938
Raises:
4039
TagExtractionError: An error occurred trying to extract content from the
4140
generated output. The LLM probably failed to open and close
42-
the \<constraints_and_conditions\> tags.
41+
the \<constraints_and_requirements\> tags.
4342
"""
4443
constraint_extractor_match = re.search(RE_VERIFIED_CONS_COND, generated_str)
4544

@@ -51,7 +50,7 @@ def _default_parser(generated_str: str) -> list[str]:
5150

5251
if constraint_extractor_str is None:
5352
raise TagExtractionError(
54-
'LLM failed to generate correct tags for extraction: "<constraints_and_conditions>"'
53+
'LLM failed to generate correct tags for extraction: "<constraints_and_requirements>"'
5554
)
5655

5756
# TODO: Maybe replace this logic with a RegEx?
@@ -76,13 +75,13 @@ def generate( # type: ignore[override]
7675
self,
7776
mellea_session: MelleaSession,
7877
input_str: str | None,
79-
max_new_tokens: int = 8192,
78+
max_new_tokens: int = 4096,
8079
parser: Callable[[str], T] = _default_parser, # type: ignore[assignment]
8180
# About the mypy ignore above: https://github.com/python/mypy/issues/3737
8281
enforce_same_words: bool = False,
8382
**kwargs: dict[str, Any],
8483
) -> PromptModuleString[T]:
85-
"""Generates an unordered list of identified constraints based on a provided task prompt.
84+
"""Generates an unordered list of identified constraints and requirements based on a provided task prompt.
8685
8786
_**Disclaimer**: This is an LLM-prompting module, so the results will vary depending
8887
on the size and capabilities of the LLM used. The results are also not guaranteed, so
@@ -112,12 +111,13 @@ def generate( # type: ignore[override]
112111
system_prompt = get_system_prompt(enforce_same_words=enforce_same_words)
113112
user_prompt = get_user_prompt(task_prompt=input_str)
114113

115-
instruction = Instruction(description=user_prompt, prefix=system_prompt)
114+
action = Message("user", user_prompt)
116115

117116
try:
118117
gen_result = mellea_session.act(
119-
action=instruction,
118+
action=action,
120119
model_options={
120+
ModelOption.SYSTEM_PROMPT: system_prompt,
121121
ModelOption.TEMPERATURE: 0,
122122
ModelOption.MAX_NEW_TOKENS: max_new_tokens,
123123
},

cli/decompose/prompt_modules/constraint_extractor/_prompt/_icl_examples/_example_1/_example.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,13 @@
99

1010
example: ICLExample = {
1111
"task_prompt": task_prompt.strip(),
12-
"constraints_and_conditions": [],
12+
"constraints_and_requirements": [],
1313
}
1414

15-
example["constraints_and_conditions"] = [
15+
example["constraints_and_requirements"] = [
1616
"Your answers should not include harmful, unethical, racist, sexist, toxic, dangerous, or illegal content",
17-
"If a question does not make sense, or not factually coherent, explain to the user why, instead of just answering something incorrect",
1817
"You must always answer the user with markdown formatting",
19-
"The markdown formats you can use are the following: heading; link; table; list; code block; block quote; bold; italic",
20-
"When answering with code blocks, include the language",
18+
"The only markdown formats you can use are the following: heading; link; table; list; code block; block quote; bold; italic",
2119
"All HTML tags must be enclosed in block quotes",
2220
"The personas must include the following properties: name; age; occupation; demographics; goals; behaviors; pain points; motivations",
2321
"The assistant must provide a comprehensive understanding of the target audience",

cli/decompose/prompt_modules/constraint_extractor/_prompt/_icl_examples/_example_1/task_prompt.txt

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ If a question does not make sense, or not factually coherent, explain to the use
55

66
You must always answer the user with markdown formatting.
77

8-
The markdown formats you can use are the following:
8+
The only markdown formats you can use are the following:
99
- heading
1010
- link
1111
- table
@@ -15,7 +15,6 @@ The markdown formats you can use are the following:
1515
- bold
1616
- italic
1717

18-
When answering with code blocks, include the language.
1918
You can be penalized if you write code outside of code blocks.
2019

2120
All HTML tags must be enclosed in block quotes, for example:

cli/decompose/prompt_modules/constraint_extractor/_prompt/_icl_examples/_example_2/_example.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,10 @@
99

1010
example: ICLExample = {
1111
"task_prompt": task_prompt.strip(),
12-
"constraints_and_conditions": [],
12+
"constraints_and_requirements": [],
1313
}
1414

15-
example["constraints_and_conditions"] = [
16-
"Emphasize the responsibilities and support offered to survivors of crime",
15+
example["constraints_and_requirements"] = [
1716
"Ensure the word 'assistance' appears less than 4 times",
1817
"Wrap the entire response with double quotation marks",
1918
]

0 commit comments

Comments
 (0)