@@ -104,6 +104,9 @@ def __init__(self, config: DatabaseConfig):
104104 if config .db_path and os .path .exists (config .db_path ):
105105 self .load (config .db_path )
106106
107+ # Prompt log
108+ self .prompts_by_program : Dict [str , Dict [str , Dict [str , str ]]] = None
109+
107110 # Set random seed for reproducible sampling if specified
108111 if config .random_seed is not None :
109112 import random
@@ -314,7 +317,14 @@ def save(self, path: Optional[str] = None, iteration: int = 0) -> None:
314317
315318 # Save each program
316319 for program in self .programs .values ():
317- self ._save_program (program , save_path )
320+ prompts = None
321+ if (
322+ self .config .log_prompts
323+ and self .prompts_by_program
324+ and program .id in self .prompts_by_program
325+ ):
326+ prompts = self .prompts_by_program [program .id ]
327+ self ._save_program (program , save_path , prompts = prompts )
318328
319329 # Save metadata
320330 metadata = {
@@ -382,13 +392,19 @@ def load(self, path: str) -> None:
382392
383393 logger .info (f"Loaded database with { len (self .programs )} programs from { path } " )
384394
385- def _save_program (self , program : Program , base_path : Optional [str ] = None ) -> None :
395+ def _save_program (
396+ self ,
397+ program : Program ,
398+ base_path : Optional [str ] = None ,
399+ prompts : Optional [Dict [str , Dict [str , str ]]] = None ,
400+ ) -> None :
386401 """
387402 Save a program to disk
388403
389404 Args:
390405 program: Program to save
391406 base_path: Base path to save to (uses config.db_path if None)
407+ prompts: Optional prompts to save with the program, in the format {template_key: { 'system': str, 'user': str }}
392408 """
393409 save_path = base_path or self .config .db_path
394410 if not save_path :
@@ -399,9 +415,13 @@ def _save_program(self, program: Program, base_path: Optional[str] = None) -> No
399415 os .makedirs (programs_dir , exist_ok = True )
400416
401417 # Save program
418+ program_dict = program .to_dict ()
419+ if prompts :
420+ program_dict ["prompts" ] = prompts
402421 program_path = os .path .join (programs_dir , f"{ program .id } .json" )
422+
403423 with open (program_path , "w" ) as f :
404- json .dump (program . to_dict () , f )
424+ json .dump (program_dict , f )
405425
406426 def _calculate_feature_coords (self , program : Program ) -> List [int ]:
407427 """
@@ -1079,3 +1099,35 @@ def _load_artifact_dir(self, artifact_dir: str) -> Dict[str, Union[str, bytes]]:
10791099 logger .warning (f"Failed to list artifact directory { artifact_dir } : { e } " )
10801100
10811101 return artifacts
1102+
1103+ def log_prompt (
1104+ self ,
1105+ program_id : str ,
1106+ template_key : str ,
1107+ prompt : Dict [str , str ],
1108+ responses : Optional [List [str ]] = None ,
1109+ ) -> None :
1110+ """
1111+ Log a prompt for a program.
1112+ Only logs if self.config.log_prompts is True.
1113+
1114+ Args:
1115+ program_id: ID of the program to log the prompt for
1116+ template_key: Key for the prompt template
1117+ prompt: Prompts in the format {template_key: { 'system': str, 'user': str }}.
1118+ responses: Optional list of responses to the prompt, if available.
1119+ """
1120+
1121+ if not self .config .log_prompts :
1122+ return
1123+
1124+ if responses is None :
1125+ responses = []
1126+ prompt ["responses" ] = responses
1127+
1128+ if self .prompts_by_program is None :
1129+ self .prompts_by_program = {}
1130+
1131+ if program_id not in self .prompts_by_program :
1132+ self .prompts_by_program [program_id ] = {}
1133+ self .prompts_by_program [program_id ][template_key ] = prompt
0 commit comments