Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 163 additions & 43 deletions app/desktop/studio_server/gepa_job_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import logging
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Literal, cast

Expand Down Expand Up @@ -40,13 +41,25 @@
)
from fastapi import FastAPI, HTTPException
from kiln_ai.datamodel import GepaJob, Project, Prompt
from kiln_ai.datamodel.prompt import BasePrompt
from kiln_ai.datamodel.task import TaskRunConfig
from kiln_ai.utils.config import Config
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# locks per job ID to prevent race conditions when creating artifacts
_job_locks: dict[str, asyncio.Lock] = {}


def _get_job_lock(job_id: str) -> asyncio.Lock:
    """Return the asyncio.Lock for *job_id*, creating and caching it on first use.

    Safe without extra synchronization: the event loop is single-threaded and
    there is no await between the lookup and the insert.
    """
    lock = _job_locks.get(job_id)
    if lock is None:
        lock = asyncio.Lock()
        _job_locks[job_id] = lock
    return lock


def is_job_status_final(status: str) -> bool:
"""
Expand Down Expand Up @@ -116,20 +129,154 @@ def gepa_job_from_id(project_id: str, task_id: str, gepa_job_id: str) -> GepaJob
return gepa_job


async def update_gepa_job_status_and_create_prompt(
def create_prompt_from_optimization(
gepa_job: GepaJob, task, optimized_prompt_text: str
) -> Prompt:
"""
Create a prompt from an optimization job result. Does not guarantee idempotence so
make sure you have a proper locking mechanism around calling this function.
"""
prompt = Prompt(
name=f"Kiln Optimized - {gepa_job.name}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is gepa_job.name going to be exactly? Is it an auto generated name like "Sparkling Dolphin"?

Regardless, since I'm now storing kiln optimized generator id we don't need the "Kiln Optimized - " prefix for the name, we will display Kiln Optimized as the Type in the table.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes name is an adjective+noun compound.

I saw we show Kiln Optimized in the table, but not sure about what to call the whole prompt.

Just the compound adj+noun?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes just the adj+noun is good enough here!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it the same adj+noun as the job name (it does not have to be the same, but it makes things easier to trace back while browsing around)

description=f"Prompt optimized by Kiln Prompt Optimizer {gepa_job.id}",
generator_id="kiln_prompt_optimizer",
prompt=optimized_prompt_text,
parent=task,
)
prompt.save_to_file()

logger.info(f"Created prompt {prompt.id} from GEPA job {gepa_job.job_id}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really log info in other api files, so maybe remove?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As well as a few other places in this file

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the log infos (kept one for the project ZIP intentionally - getting the ZIP size is helpful for debugging if something goes wrong)

return prompt


def create_run_config_from_optimization(
    gepa_job: GepaJob, task, optimized_prompt_text: str
) -> TaskRunConfig | None:
    """
    Create a run config from an optimization job result. Does not guarantee idempotence so
    make sure you have a proper locking mechanism around calling this function.

    Returns the saved TaskRunConfig, or None on any failure — callers treat a
    missing run config as non-fatal.
    """
    # Previously these precondition failures raised HTTPException(500) inside the
    # try block below, but the broad `except Exception` immediately swallowed them
    # and returned None — the HTTP error could never reach a client. Check the
    # preconditions up front and return None explicitly instead.
    parent_project = task.parent_project()
    if parent_project is None or not parent_project.id or not task.id:
        logger.error(
            "Cannot create run config from GEPA job: task has no parent project or id"
        )
        return None

    try:
        # get the original target run config that we optimized for in the job
        target_run_config = task_run_config_from_id(
            parent_project.id, task.id, gepa_job.target_run_config_id
        )

        date_str = datetime.now().strftime("%Y-%m-%d")
        run_config_name = f"{target_run_config.name} optimized-{date_str}"

        # freeze the optimized prompt text into the run config so later edits to
        # the prompt library cannot change this run config's behavior
        frozen_prompt = BasePrompt(
            name=f"Kiln Optimized - {run_config_name}",
            description=f"Kiln Optimized prompt for {run_config_name}",
            generator_id="kiln_prompt_optimizer",
            prompt=optimized_prompt_text,
        )

        # create new run config with the same properties but new prompt
        # (model_copy so we never mutate the target run config's properties)
        new_run_config_properties = target_run_config.run_config_properties.model_copy()

        new_run_config = TaskRunConfig(
            parent=task,
            name=run_config_name,
            description=f"Optimized run config from GEPA job {gepa_job.name}",
            run_config_properties=new_run_config_properties,
            prompt=frozen_prompt,
        )

        # point the run config properties to the frozen prompt
        new_run_config.run_config_properties.prompt_id = (
            f"task_run_config::{parent_project.id}::{task.id}::{new_run_config.id}"
        )

        new_run_config.save_to_file()

        logger.info(
            f"Created run config {new_run_config.id} from GEPA job {gepa_job.job_id}"
        )
        return new_run_config

    except Exception as e:
        # Unexpected failures (remote lookup, disk write, validation) are logged
        # and reported as None rather than propagated.
        logger.error(f"Error creating run config from GEPA job: {e}", exc_info=True)
        return None


async def _create_artifacts_for_succeeded_job(
    gepa_job: GepaJob,
    task,
    server_client: AuthenticatedClient,
) -> None:
    """
    Create prompt and run config artifacts for a newly succeeded GEPA job.
    Assumes caller has acquired the job lock. Modifies gepa_job in place.

    Raises ValueError if any of the IDs needed to reload the job from disk
    are missing. If the remote result response is missing or malformed, no
    artifacts are created and gepa_job is left untouched (silent no-op).
    """
    parent_project = task.parent_project()
    if not parent_project or not parent_project.id or not task.id or not gepa_job.id:
        raise ValueError("Cannot reload GEPA job: missing required IDs")

    # reload the job in case artifacts were created by another request while waiting for the lock
    reloaded_job = gepa_job_from_id(
        parent_project.id,
        task.id,
        gepa_job.id,
    )

    # check if artifacts already exist
    if reloaded_job.created_prompt_id:
        logger.info(
            f"Artifacts already exist for GEPA job {gepa_job.job_id}, skipping creation"
        )
        # sync the in-memory copy with what the other request already persisted,
        # so the caller's gepa_job reflects the artifacts without re-creating them
        gepa_job.created_prompt_id = reloaded_job.created_prompt_id
        gepa_job.created_run_config_id = reloaded_job.created_run_config_id
        gepa_job.optimized_prompt = reloaded_job.optimized_prompt
        return

    # fetch the optimization result from the remote GEPA service
    result_response = (
        await get_gepa_job_result_v1_jobs_gepa_job_job_id_result_get.asyncio(
            job_id=gepa_job.job_id,
            client=server_client,
        )
    )

    # only proceed when the response is a well-formed result carrying an
    # optimized prompt; any other shape (None, validation error, empty output)
    # is deliberately ignored
    if (
        result_response
        and not isinstance(result_response, HTTPValidationError)
        and result_response.output
        and hasattr(result_response.output, "optimized_prompt")
    ):
        optimized_prompt_text = result_response.output.optimized_prompt
        gepa_job.optimized_prompt = optimized_prompt_text

        prompt = create_prompt_from_optimization(gepa_job, task, optimized_prompt_text)
        # NOTE(review): prompt id is stored with an "id::" prefix, unlike the
        # run config id below which is stored bare — presumably a prompt-id
        # addressing convention; confirm against consumers of created_prompt_id
        gepa_job.created_prompt_id = f"id::{prompt.id}"

        # run config creation is best-effort: it returns None on failure and we
        # still keep the created prompt
        run_config = create_run_config_from_optimization(
            gepa_job, task, optimized_prompt_text
        )
        if run_config:
            gepa_job.created_run_config_id = run_config.id


async def update_gepa_job_and_create_artifacts(
gepa_job: GepaJob, server_client: AuthenticatedClient
) -> GepaJob:
"""
Update the status of a GepaJob from the remote server.
If the job has succeeded and no prompt exists yet, create one.
If the job has succeeded for the first time, create a prompt and run config from the result.
Uses per-job locking to ensure the success transition is handled atomically.
"""

task = gepa_job.parent_task()
if task is None:
raise HTTPException(status_code=500, detail="GepaJob has no parent task")

previous_status = gepa_job.latest_status

try:
status_response = (
await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio(
Expand All @@ -144,44 +291,19 @@ async def update_gepa_job_status_and_create_prompt(
return gepa_job

new_status = str(status_response.status.value)
gepa_job.latest_status = new_status

if (
previous_status != JobStatus.SUCCEEDED
and new_status == JobStatus.SUCCEEDED
and not gepa_job.created_prompt_id
):
result_response = (
await get_gepa_job_result_v1_jobs_gepa_job_job_id_result_get.asyncio(
job_id=gepa_job.job_id,
client=server_client,
)
)

lock = _get_job_lock(gepa_job.job_id)
async with lock:
previous_status = gepa_job.latest_status
gepa_job.latest_status = new_status

if (
result_response
and not isinstance(result_response, HTTPValidationError)
and result_response.output
and hasattr(result_response.output, "optimized_prompt")
previous_status != JobStatus.SUCCEEDED
and new_status == JobStatus.SUCCEEDED
):
optimized_prompt_text = result_response.output.optimized_prompt
gepa_job.optimized_prompt = optimized_prompt_text

prompt = Prompt(
name=f"GEPA - {gepa_job.name}",
description=f"Optimized prompt generated by GEPA job {gepa_job.id}",
generator_id="kiln_prompt_optimizer",
prompt=optimized_prompt_text,
parent=task,
)
prompt.save_to_file()
await _create_artifacts_for_succeeded_job(gepa_job, task, server_client)

gepa_job.created_prompt_id = f"id::{prompt.id}"
logger.info(
f"Created prompt {prompt.id} from GEPA job {gepa_job.job_id}"
)

gepa_job.save_to_file()
gepa_job.save_to_file()

except Exception as e:
logger.error(f"Error updating GEPA job status: {e}", exc_info=True)
Expand Down Expand Up @@ -524,9 +646,7 @@ async def list_gepa_jobs(
)
await asyncio.gather(
*[
update_gepa_job_status_and_create_prompt(
job, server_client
)
update_gepa_job_and_create_artifacts(job, server_client)
for job in batch
]
)
Expand All @@ -552,7 +672,7 @@ async def get_gepa_job(project_id: str, task_id: str, gepa_job_id: str) -> GepaJ
try:
server_client = get_authenticated_client(_get_api_key())
if isinstance(server_client, AuthenticatedClient):
gepa_job = await update_gepa_job_status_and_create_prompt(
gepa_job = await update_gepa_job_and_create_artifacts(
gepa_job, server_client
)
except Exception as e:
Expand Down
Loading
Loading