-
Notifications
You must be signed in to change notification settings - Fork 347
gepa: create run config from gepa job completion #1014
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
c75b5a4
c03a49c
d71e323
ca1717c
8661445
85fa3f8
4a430bd
07cc3de
38e76e2
4a69d40
e3231b4
48c2eec
77263a8
0492187
b7374d8
95089d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| import io | ||
| import logging | ||
| import zipfile | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| from typing import Literal, cast | ||
|
|
||
|
|
@@ -40,7 +41,10 @@ | |
| ) | ||
| from fastapi import FastAPI, HTTPException | ||
| from kiln_ai.datamodel import GepaJob, Project, Prompt | ||
| from kiln_ai.datamodel.prompt import BasePrompt | ||
| from kiln_ai.datamodel.task import TaskRunConfig | ||
| from kiln_ai.utils.config import Config | ||
| from kiln_ai.utils.lock import shared_async_lock_manager | ||
| from kiln_ai.utils.name_generator import generate_memorable_name | ||
| from kiln_server.task_api import task_from_id | ||
| from pydantic import BaseModel | ||
|
|
@@ -116,20 +120,154 @@ def gepa_job_from_id(project_id: str, task_id: str, gepa_job_id: str) -> GepaJob | |
| return gepa_job | ||
|
|
||
|
|
||
| async def update_gepa_job_status_and_create_prompt( | ||
def create_prompt_from_optimization(
    gepa_job: GepaJob, task, optimized_prompt_text: str
) -> Prompt:
    """
    Build a Prompt from the optimized text of a GEPA job and save it to file.

    The prompt is parented to `task` and named after the job. This function is
    not idempotent: every call builds and saves a new prompt, so the caller is
    responsible for locking around it.
    """
    prompt_name = f"Kiln Optimized - {gepa_job.name}"
    prompt_description = f"Prompt optimized by Kiln Prompt Optimizer {gepa_job.id}"

    optimized_prompt = Prompt(
        name=prompt_name,
        description=prompt_description,
        generator_id="kiln_prompt_optimizer",
        prompt=optimized_prompt_text,
        parent=task,
    )
    optimized_prompt.save_to_file()

    created_message = (
        f"Created prompt {optimized_prompt.id} from GEPA job {gepa_job.job_id}"
    )
    logger.info(created_message)
    return optimized_prompt
|
|
||
|
|
||
def create_run_config_from_optimization(
    gepa_job: GepaJob, task, optimized_prompt_text: str
) -> TaskRunConfig | None:
    """
    Build and save a run config that pairs the job's target run config
    properties with the optimized prompt text.

    Best-effort: any failure (including the missing-ID checks below) is logged
    and reported to the caller as None rather than raised.

    Not idempotent: each successful call saves a new run config, so the caller
    is responsible for locking around it.
    """
    try:
        project = task.parent_project()
        if project is None:
            raise HTTPException(status_code=500, detail="Task has no parent project")

        if not (project.id and task.id):
            raise HTTPException(
                status_code=500, detail="Task has no parent project or task"
            )

        # The run config the optimization job was targeting.
        source_run_config = task_run_config_from_id(
            project.id, task.id, gepa_job.target_run_config_id
        )

        today = datetime.now().strftime("%Y-%m-%d")
        optimized_name = f"{source_run_config.name} optimized-{today}"

        # Copy the source properties so the new config can point at its own prompt.
        copied_properties = source_run_config.run_config_properties.model_copy()

        embedded_prompt = BasePrompt(
            name=f"Kiln Optimized - {optimized_name}",
            description=f"Kiln Optimized prompt for {optimized_name}",
            generator_id="kiln_prompt_optimizer",
            prompt=optimized_prompt_text,
        )

        optimized_run_config = TaskRunConfig(
            parent=task,
            name=optimized_name,
            description=f"Optimized run config from GEPA job {gepa_job.name}",
            run_config_properties=copied_properties,
            prompt=embedded_prompt,
        )

        # The prompt ID embeds the new run config's own ID, so it can only be
        # set after the run config has been constructed.
        frozen_prompt_id = (
            f"task_run_config::{project.id}::{task.id}::{optimized_run_config.id}"
        )
        optimized_run_config.run_config_properties.prompt_id = frozen_prompt_id

        optimized_run_config.save_to_file()

        logger.info(
            f"Created run config {optimized_run_config.id} from GEPA job {gepa_job.job_id}"
        )
        return optimized_run_config

    except Exception as e:
        logger.error(f"Error creating run config from GEPA job: {e}", exc_info=True)
        return None
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
async def _create_artifacts_for_succeeded_job(
    gepa_job: GepaJob,
    task,
    server_client: AuthenticatedClient,
) -> None:
    """
    Create prompt and run config artifacts for a newly succeeded GEPA job.

    Assumes the caller has acquired the job lock. Modifies `gepa_job` in place
    (optimized_prompt, created_prompt_id, created_run_config_id) but does not
    save it.

    Raises:
        ValueError: if the parent project, its ID, the task ID or the job ID
            is missing, since the job cannot be reloaded without them.
    """
    parent_project = task.parent_project()
    if not parent_project or not parent_project.id or not task.id or not gepa_job.id:
        raise ValueError("Cannot reload GEPA job: missing required IDs")

    # Reload the job in case artifacts were created by another request while
    # this one was waiting for the lock.
    reloaded_job = gepa_job_from_id(
        parent_project.id,
        task.id,
        gepa_job.id,
    )

    # A recorded prompt ID means the artifacts already exist: mirror the
    # stored fields onto the in-memory job and stop.
    if reloaded_job.created_prompt_id:
        logger.info(
            f"Artifacts already exist for GEPA job {gepa_job.job_id}, skipping creation"
        )
        gepa_job.created_prompt_id = reloaded_job.created_prompt_id
        gepa_job.created_run_config_id = reloaded_job.created_run_config_id
        gepa_job.optimized_prompt = reloaded_job.optimized_prompt
        return

    result_response = (
        await get_gepa_job_result_v1_jobs_gepa_job_job_id_result_get.asyncio(
            job_id=gepa_job.job_id,
            client=server_client,
        )
    )

    has_optimized_prompt = (
        result_response
        and not isinstance(result_response, HTTPValidationError)
        and result_response.output
        and hasattr(result_response.output, "optimized_prompt")
    )
    if not has_optimized_prompt:
        # Previously this path returned silently, leaving no trace of why a
        # succeeded job ended up without a prompt or run config.
        logger.warning(
            f"GEPA job {gepa_job.job_id} succeeded but no optimized prompt was "
            "returned; skipping artifact creation"
        )
        return

    optimized_prompt_text = result_response.output.optimized_prompt
    gepa_job.optimized_prompt = optimized_prompt_text

    prompt = create_prompt_from_optimization(gepa_job, task, optimized_prompt_text)
    gepa_job.created_prompt_id = f"id::{prompt.id}"

    # create_run_config_from_optimization returns None on failure; the prompt
    # ID above is still kept in that case.
    run_config = create_run_config_from_optimization(
        gepa_job, task, optimized_prompt_text
    )
    if run_config:
        gepa_job.created_run_config_id = run_config.id
|
|
||
|
|
||
| async def update_gepa_job_and_create_artifacts( | ||
| gepa_job: GepaJob, server_client: AuthenticatedClient | ||
| ) -> GepaJob: | ||
| """ | ||
| Update the status of a GepaJob from the remote server. | ||
| If the job has succeeded and no prompt exists yet, create one. | ||
| If the job has succeeded for the first time, create a prompt and run config from the result. | ||
| Uses per-job locking to ensure the success transition is handled atomically. | ||
| """ | ||
|
|
||
| task = gepa_job.parent_task() | ||
| if task is None: | ||
| raise HTTPException(status_code=500, detail="GepaJob has no parent task") | ||
|
|
||
| previous_status = gepa_job.latest_status | ||
|
|
||
| try: | ||
| status_response = ( | ||
| await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio( | ||
|
|
@@ -144,44 +282,18 @@ async def update_gepa_job_status_and_create_prompt( | |
| return gepa_job | ||
|
|
||
| new_status = str(status_response.status.value) | ||
| gepa_job.latest_status = new_status | ||
|
|
||
| if ( | ||
| previous_status != JobStatus.SUCCEEDED | ||
| and new_status == JobStatus.SUCCEEDED | ||
| and not gepa_job.created_prompt_id | ||
| ): | ||
| result_response = ( | ||
| await get_gepa_job_result_v1_jobs_gepa_job_job_id_result_get.asyncio( | ||
| job_id=gepa_job.job_id, | ||
| client=server_client, | ||
| ) | ||
| ) | ||
|
|
||
| async with shared_async_lock_manager.acquire(gepa_job.job_id): | ||
| previous_status = gepa_job.latest_status | ||
| gepa_job.latest_status = new_status | ||
|
|
||
| if ( | ||
| result_response | ||
| and not isinstance(result_response, HTTPValidationError) | ||
| and result_response.output | ||
| and hasattr(result_response.output, "optimized_prompt") | ||
| previous_status != JobStatus.SUCCEEDED | ||
| and new_status == JobStatus.SUCCEEDED | ||
| ): | ||
| optimized_prompt_text = result_response.output.optimized_prompt | ||
| gepa_job.optimized_prompt = optimized_prompt_text | ||
|
|
||
| prompt = Prompt( | ||
| name=f"GEPA - {gepa_job.name}", | ||
| description=f"Optimized prompt generated by GEPA job {gepa_job.id}", | ||
| generator_id="kiln_prompt_optimizer", | ||
| prompt=optimized_prompt_text, | ||
| parent=task, | ||
| ) | ||
| prompt.save_to_file() | ||
| await _create_artifacts_for_succeeded_job(gepa_job, task, server_client) | ||
|
|
||
| gepa_job.created_prompt_id = f"id::{prompt.id}" | ||
| logger.info( | ||
| f"Created prompt {prompt.id} from GEPA job {gepa_job.job_id}" | ||
| ) | ||
|
|
||
| gepa_job.save_to_file() | ||
| gepa_job.save_to_file() | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"Error updating GEPA job status: {e}", exc_info=True) | ||
|
|
@@ -524,9 +636,7 @@ async def list_gepa_jobs( | |
| ) | ||
| await asyncio.gather( | ||
| *[ | ||
| update_gepa_job_status_and_create_prompt( | ||
| job, server_client | ||
| ) | ||
| update_gepa_job_and_create_artifacts(job, server_client) | ||
| for job in batch | ||
| ] | ||
| ) | ||
|
|
@@ -552,7 +662,7 @@ async def get_gepa_job(project_id: str, task_id: str, gepa_job_id: str) -> GepaJ | |
| try: | ||
| server_client = get_authenticated_client(_get_api_key()) | ||
| if isinstance(server_client, AuthenticatedClient): | ||
| gepa_job = await update_gepa_job_status_and_create_prompt( | ||
| gepa_job = await update_gepa_job_and_create_artifacts( | ||
| gepa_job, server_client | ||
| ) | ||
| except Exception as e: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is gepa_job.name going to be exactly? Is it an auto generated name like "Sparkling Dolphin"?
Regardless, since I'm now storing kiln optimized generator id we don't need the "Kiln Optimized - " prefix for the name, we will display Kiln Optimized as the Type in the table.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes name is an adjective+noun compound.
I saw we show Kiln Optimized in the table, but not sure about what to call the whole prompt.
Just the compound adj+noun?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes just the adj+noun is good enough here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made it the same adj+noun as the job name (it does not have to be the same, but it makes it easier to trace things back while browsing around).