
Commit bfb94a8

Make APO templates configurable via constructor arguments (#443)

1 parent 25eda47 commit bfb94a8

File tree

  • agentlightning/algorithm/apo/apo.py

1 file changed: +8 -2 lines changed

agentlightning/algorithm/apo/apo.py

Lines changed: 8 additions & 2 deletions
@@ -112,6 +112,8 @@ def __init__(
         beam_rounds: int = 3,
         rollout_batch_timeout: float = 3600.0,
         run_initial_validation: bool = True,
+        gradient_prompt_files: Optional[List[Path]] = None,
+        apply_edit_prompt_files: Optional[List[Path]] = None,
         # Internal flags for debugging
         _poml_trace: bool = False,
     ):
@@ -132,6 +134,8 @@ def __init__(
             rollout_batch_timeout: Maximum time in seconds to wait for rollout batch completion.
             run_initial_validation: If True, runs validation on the seed prompt before starting
                 optimization to establish a baseline score. Defaults to True.
+            gradient_prompt_files: Prompt templates used to compute textual gradients (critiques).
+            apply_edit_prompt_files: Prompt templates used to apply edits based on critiques.
         """
         self.async_openai_client = async_openai_client
         self.gradient_model = gradient_model
@@ -144,6 +148,8 @@ def __init__(
         self.beam_rounds = beam_rounds
         self.rollout_batch_timeout = rollout_batch_timeout
         self.run_initial_validation = run_initial_validation
+        self.gradient_prompt_files = gradient_prompt_files or GRADIENT_PROMPT_FILES
+        self.apply_edit_prompt_files = apply_edit_prompt_files or APPLY_EDIT_PROMPT_FILES

         self._history_best_prompt: Optional[PromptTemplate] = None
         self._history_best_score: float = float("-inf")
@@ -270,7 +276,7 @@ async def compute_textual_gradient(
         Returns:
             A textual critique generated by the LLM, or None if generation fails.
         """
-        tg_template = random.choice(GRADIENT_PROMPT_FILES)
+        tg_template = random.choice(self.gradient_prompt_files)

         if len(rollout_results) < self.gradient_batch_size:
             self._log(
@@ -352,7 +358,7 @@ async def textual_gradient_and_apply_edit(
             return current_prompt.prompt_template.template

         # 2) Apply edit
-        ae_template = random.choice(APPLY_EDIT_PROMPT_FILES)
+        ae_template = random.choice(self.apply_edit_prompt_files)
         self._log(
             logging.INFO,
             f"Edit will be generated by {self.apply_edit_model} with template: {ae_template.name}",
