@@ -112,6 +112,8 @@ def __init__(
112112 beam_rounds : int = 3 ,
113113 rollout_batch_timeout : float = 3600.0 ,
114114 run_initial_validation : bool = True ,
115+ gradient_prompt_files : Optional [List [Path ]] = None ,
116+ apply_edit_prompt_files : Optional [List [Path ]] = None ,
115117 # Internal flags for debugging
116118 _poml_trace : bool = False ,
117119 ):
@@ -132,6 +134,8 @@ def __init__(
132134 rollout_batch_timeout: Maximum time in seconds to wait for rollout batch completion.
133135 run_initial_validation: If True, runs validation on the seed prompt before starting
134136 optimization to establish a baseline score. Defaults to True.
137+ gradient_prompt_files: Prompt templates used to compute textual gradients (critiques).
138+ apply_edit_prompt_files: Prompt templates used to apply edits based on critiques.
135139 """
136140 self .async_openai_client = async_openai_client
137141 self .gradient_model = gradient_model
@@ -144,6 +148,8 @@ def __init__(
144148 self .beam_rounds = beam_rounds
145149 self .rollout_batch_timeout = rollout_batch_timeout
146150 self .run_initial_validation = run_initial_validation
151+ self .gradient_prompt_files = gradient_prompt_files or GRADIENT_PROMPT_FILES
152+ self .apply_edit_prompt_files = apply_edit_prompt_files or APPLY_EDIT_PROMPT_FILES
147153
148154 self ._history_best_prompt : Optional [PromptTemplate ] = None
149155 self ._history_best_score : float = float ("-inf" )
@@ -270,7 +276,7 @@ async def compute_textual_gradient(
270276 Returns:
271277 A textual critique generated by the LLM, or None if generation fails.
272278 """
273- tg_template = random .choice (GRADIENT_PROMPT_FILES )
279+ tg_template = random .choice (self . gradient_prompt_files )
274280
275281 if len (rollout_results ) < self .gradient_batch_size :
276282 self ._log (
@@ -352,7 +358,7 @@ async def textual_gradient_and_apply_edit(
352358 return current_prompt .prompt_template .template
353359
354360 # 2) Apply edit
355- ae_template = random .choice (APPLY_EDIT_PROMPT_FILES )
361+ ae_template = random .choice (self . apply_edit_prompt_files )
356362 self ._log (
357363 logging .INFO ,
358364 f"Edit will be generated by { self .apply_edit_model } with template: { ae_template .name } " ,
0 commit comments