
Commit cd51b73

Author: Tulio Coppola
Commit message: moves prompt_modules to utils
1 parent d2f8815 commit cd51b73

File tree: 113 files changed (+165, -30 lines)


cli/decompose/decompose.py

Lines changed: 74 additions & 6 deletions
@@ -5,6 +5,8 @@
 
 import typer
 
+from .pipeline import DecompBackend
+
 this_file_dir = Path(__file__).resolve().parent
 
 
@@ -20,15 +22,69 @@ def run(
         typer.FileText | None,
         typer.Option(help="Path to a raw text file containing a task prompt."),
     ] = None,
+    model_id: Annotated[
+        str,
+        typer.Option(
+            help=(
+                "Model name/id to be used to run the decomposition pipeline."
+                + ' Defaults to "mistral-small3.2:latest", which is valid for the "ollama" backend.'
+                + " If you have a vLLM instance serving a model from HF with vLLM's OpenAI"
+                + " compatible endpoint, then this option should be set to the model's HF name/id,"
+                + ' e.g. "mistralai/Mistral-Small-3.2-24B-Instruct-2506" and the "--backend" option'
+                + ' should be set to "openai".'
+            )
+        ),
+    ] = "mistral-small3.2:latest",
+    backend: Annotated[
+        DecompBackend,
+        typer.Option(
+            help=(
+                'Backend to be used for inference. Defaults to "ollama".'
+                + ' Options are: "ollama" and "openai".'
+                + ' The "ollama" backend runs a local inference server.'
+                + ' The "openai" backend will send inference requests to any'
+                + " endpoint that's OpenAI compatible."
+            ),
+            case_sensitive=False,
+        ),
+    ] = DecompBackend.ollama,
+    backend_req_timeout: Annotated[
+        int,
+        typer.Option(
+            help='Time (in seconds) for timeout to be passed on the model inference requests. Defaults to "3600"'
+        ),
+    ] = 3600,
+    backend_endpoint: Annotated[
+        str | None,
+        typer.Option(
+            help=(
+                'The "endpoint URL", sometimes called "base URL",'
+                + ' to reach the model when using the "openai" backend.'
+                + ' This option is required if using "--backend openai".'
+            )
+        ),
+    ] = None,
+    backend_api_key: Annotated[
+        str | None,
+        typer.Option(
+            help=(
+                'The API key for the configured "--backend-endpoint".'
+                + ' If using "--backend openai" this option must be set,'
+                + " even if you are running locally (an OpenAI compatible server), you"
+                + ' must set this option, it can be set to "EMPTY" if your local'
+                + " server doesn't need it."
+            )
+        ),
+    ] = None,
     input_var: Annotated[
         list[str] | None,
         typer.Option(
             help=(
-                "If your task prompt needs user input data, you must pass"
-                + " a descriptive variable name using this option,"
-                + " so the name can be included when generating the prompt templates."
+                "If your task needs user input data, you must pass"
+                + " a descriptive variable name using this option, this way"
+                + " the variable names can be templated into the generated prompts."
                 + " You can pass this option multiple times, one for each input variable name."
-                + " These names must be all uppercase, alphanumeric with words separated by underscores."
+                + " These names must be all uppercase, alphanumeric, with words separated by underscores."
             )
         ),
     ] = None,
@@ -63,7 +119,13 @@ def run(
 
     if prompt_file:
         decomp_data = pipeline.decompose(
-            task_prompt=prompt_file.read(), user_input_variable=input_var
+            task_prompt=prompt_file.read(),
+            user_input_variable=input_var,
+            model_id=model_id,
+            backend=backend,
+            backend_req_timeout=backend_req_timeout,
+            backend_endpoint=backend_endpoint,
+            backend_api_key=backend_api_key,
         )
     else:
         task_prompt: str = typer.prompt(
@@ -76,7 +138,13 @@ def run(
         )
         task_prompt = task_prompt.replace("\\n", "\n")
         decomp_data = pipeline.decompose(
-            task_prompt=task_prompt, user_input_variable=None
+            task_prompt=task_prompt,
+            user_input_variable=None,
+            model_id=model_id,
+            backend=backend,
+            backend_req_timeout=backend_req_timeout,
+            backend_endpoint=backend_endpoint,
+            backend_api_key=backend_api_key,
         )
 
     with open(out_dir / f"{out_name}.json", "w") as f:
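
For orientation, here is a minimal sketch (not part of this commit) of the call the updated run() command now forwards to pipeline.decompose() when the default "ollama" backend is used. The task prompt text and the SUPPORT_TICKET input-variable name are illustrative placeholders, and the sketch assumes the cli package is importable and a local Ollama server is running.

from cli.decompose import pipeline
from cli.decompose.pipeline import DecompBackend

# Mirrors the CLI defaults: local Ollama backend, 3600 s request timeout.
result = pipeline.decompose(
    task_prompt="Summarize the attached support ticket and draft a reply.",  # placeholder prompt
    user_input_variable=["SUPPORT_TICKET"],  # uppercase, underscore-separated, per the --input-var help text
    model_id="mistral-small3.2:latest",
    backend=DecompBackend.ollama,
    backend_req_timeout=3600,
)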

cli/decompose/pipeline.py

Lines changed: 69 additions & 18 deletions
@@ -1,19 +1,23 @@
+from enum import Enum
 from typing import TypedDict
 
 from typing_extensions import NotRequired
 
 from mellea import MelleaSession
-from mellea.backends import ModelOption
 from mellea.backends.ollama import OllamaModelBackend
-from mellea.prompt_modules import (
+from mellea.backends.openai import OpenAIBackend
+from mellea.backends.types import ModelOption
+from mellea.helpers.prompt_modules import (
     constraint_extractor,
     subtask_constraint_assign,
     subtask_list,
     subtask_prompt_generator,
 )
-from mellea.prompt_modules.subtask_constraint_assign import SubtaskPromptConstraintsItem
-from mellea.prompt_modules.subtask_list import SubtaskItem
-from mellea.prompt_modules.subtask_prompt_generator import SubtaskPromptItem
+from mellea.helpers.prompt_modules.subtask_constraint_assign import (
+    SubtaskPromptConstraintsItem,
+)
+from mellea.helpers.prompt_modules.subtask_list import SubtaskItem
+from mellea.helpers.prompt_modules.subtask_prompt_generator import SubtaskPromptItem
 
 
 class DecompSubtasksResult(TypedDict):
@@ -32,37 +36,84 @@ class DecompPipelineResult(TypedDict):
     final_response: NotRequired[str]
 
 
+class DecompBackend(str, Enum):
+    ollama = "ollama"
+    openai = "openai"
+    rits = "rits"
+
+
 def decompose(
-    task_prompt: str, user_input_variable: list[str] | None = None
+    task_prompt: str,
+    user_input_variable: list[str] | None = None,
+    model_id: str = "mistral-small3.2:latest",
+    backend: DecompBackend = DecompBackend.ollama,
+    backend_req_timeout: int = 3600,
+    backend_endpoint: str | None = None,
+    backend_api_key: str | None = None,
 ) -> DecompPipelineResult:
     if user_input_variable is None:
         user_input_variable = []
 
-    m_ollama_session = MelleaSession(
-        OllamaModelBackend(
-            model_id="mistral-small3.2:24b",
-            model_options={ModelOption.CONTEXT_WINDOW: 32768},
-        )
-    )
+    match backend:
+        case DecompBackend.ollama:
+            m_session = MelleaSession(
+                OllamaModelBackend(
+                    model_id=model_id,
+                    model_options={
+                        ModelOption.CONTEXT_WINDOW: 32768,
+                        "timeout": backend_req_timeout,
+                    },
+                )
+            )
+        case DecompBackend.openai:
+            assert backend_endpoint is not None, (
+                'Required to provide "backend_endpoint" for this configuration'
+            )
+            assert backend_api_key is not None, (
+                'Required to provide "backend_api_key" for this configuration'
+            )
+            m_session = MelleaSession(
+                OpenAIBackend(
+                    model_id=model_id,
+                    base_url=backend_endpoint,
+                    api_key=backend_api_key,
+                    model_options={"timeout": backend_req_timeout},
+                )
+            )
+        case DecompBackend.rits:
+            assert backend_endpoint is not None, (
+                'Required to provide "backend_endpoint" for this configuration'
+            )
+            assert backend_api_key is not None, (
+                'Required to provide "backend_api_key" for this configuration'
+            )
 
-    subtasks: list[SubtaskItem] = subtask_list.generate(
-        m_ollama_session, task_prompt
-    ).parse()
+            from mellea_ibm.rits import RITSBackend, RITSModelIdentifier  # type: ignore
+
+            m_session = MelleaSession(
+                RITSBackend(
+                    RITSModelIdentifier(endpoint=backend_endpoint, model_name=model_id),
+                    api_key=backend_api_key,
+                    model_options={"timeout": backend_req_timeout},
+                )
+            )
+
+    subtasks: list[SubtaskItem] = subtask_list.generate(m_session, task_prompt).parse()
 
     task_prompt_constraints: list[str] = constraint_extractor.generate(
-        m_ollama_session, task_prompt
+        m_session, task_prompt
     ).parse()
 
     subtask_prompts: list[SubtaskPromptItem] = subtask_prompt_generator.generate(
-        m_ollama_session,
+        m_session,
         task_prompt,
         user_input_var_names=user_input_variable,
         subtasks_and_tags=subtasks,
     ).parse()
 
     subtask_prompts_with_constraints: list[SubtaskPromptConstraintsItem] = (
         subtask_constraint_assign.generate(
-            m_ollama_session,
+            m_session,
             subtasks_tags_and_prompts=subtask_prompts,
             constraint_list=task_prompt_constraints,
         ).parse()
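
As a usage sketch (again not part of the commit), the new backend selection can be exercised directly from Python via the "openai" path against a vLLM server exposing an OpenAI-compatible endpoint, as the new --model-id help text describes. The endpoint URL below is an assumed local address, and the API key is set to "EMPTY" as the --backend-api-key help text permits for local servers that do not check keys.

from cli.decompose.pipeline import DecompBackend, decompose

# "openai" backend pointed at a locally served HF model (URL is a placeholder).
result = decompose(
    task_prompt="Write a short release note for the new feature.",  # placeholder prompt
    model_id="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
    backend=DecompBackend.openai,
    backend_endpoint="http://localhost:8000/v1",  # assumed vLLM OpenAI-compatible endpoint
    backend_api_key="EMPTY",  # local server that doesn't need a real key
)

# DecompPipelineResult is a TypedDict; final_response is NotRequired.
print(result.get("final_response"))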
File renamed without changes.

mellea/prompt_modules/constraint_extractor/_constraint_extractor.py renamed to mellea/helpers/prompt_modules/constraint_extractor/_constraint_extractor.py

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,10 @@
 
 from mellea import MelleaSession
 from mellea.backends.types import ModelOption
-from mellea.prompt_modules._prompt_modules import PromptModule, PromptModuleString
+from mellea.helpers.prompt_modules._prompt_modules import (
+    PromptModule,
+    PromptModuleString,
+)
 from mellea.stdlib.instruction import Instruction
 
 from ._exceptions import BackendGenerationError, TagExtractionError
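
Code outside this commit that imports these modules needs the matching one-line path update. A before/after sketch for hypothetical consumer code, with both paths taken from the diffs above:

# Before this commit:
# from mellea.prompt_modules import constraint_extractor
# from mellea.prompt_modules.subtask_list import SubtaskItem

# After this commit:
from mellea.helpers.prompt_modules import constraint_extractor
from mellea.helpers.prompt_modules.subtask_list import SubtaskItem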
