-
Notifications
You must be signed in to change notification settings - Fork 347
refactor: prune ZIP for prompt optimization job #1031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8175df0
2e3d690
bc2c663
a8d45a9
6a57b52
3151293
c04ad39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,8 @@ | ||
| import asyncio | ||
| import io | ||
| import logging | ||
| import zipfile | ||
| import tempfile | ||
| from pathlib import Path | ||
| from typing import Literal, cast | ||
|
|
||
| from app.desktop.studio_server.api_client.kiln_ai_server_client.api.jobs import ( | ||
| check_model_supported_v1_jobs_gepa_job_check_model_supported_get, | ||
|
|
@@ -17,9 +16,6 @@ | |
| from app.desktop.studio_server.api_client.kiln_ai_server_client.models.body_start_gepa_job_v1_jobs_gepa_job_start_post import ( | ||
| BodyStartGepaJobV1JobsGepaJobStartPost, | ||
| ) | ||
| from app.desktop.studio_server.api_client.kiln_ai_server_client.models.body_start_gepa_job_v1_jobs_gepa_job_start_post_token_budget import ( | ||
| BodyStartGepaJobV1JobsGepaJobStartPostTokenBudget, | ||
| ) | ||
| from app.desktop.studio_server.api_client.kiln_ai_server_client.models.http_validation_error import ( | ||
| HTTPValidationError, | ||
| ) | ||
|
|
@@ -38,8 +34,13 @@ | |
| eval_from_id, | ||
| task_run_config_from_id, | ||
| ) | ||
| from app.desktop.studio_server.utils.copilot_utils import check_response_error | ||
| from fastapi import FastAPI, HTTPException | ||
| from kiln_ai.datamodel import GepaJob, Project, Prompt | ||
| from kiln_ai.cli.commands.package_project import ( | ||
| PackageForTrainingConfig, | ||
| package_project_for_training, | ||
| ) | ||
| from kiln_ai.datamodel import GepaJob, Prompt | ||
| from kiln_ai.datamodel.task import TaskRunConfig | ||
| from kiln_ai.utils.config import Config | ||
| from kiln_ai.utils.lock import shared_async_lock_manager | ||
|
|
@@ -101,7 +102,6 @@ def _get_api_key() -> str: | |
|
|
||
|
|
||
| class StartGepaJobRequest(BaseModel): | ||
| token_budget: Literal["light", "medium", "heavy"] | ||
| target_run_config_id: str | ||
| eval_ids: list[str] | ||
|
|
||
|
|
@@ -312,52 +312,6 @@ async def update_gepa_job_and_create_artifacts( | |
| return gepa_job | ||
|
|
||
|
|
||
def zip_project(project: Project) -> bytes:
    """
    Create a ZIP file of the entire project directory.

    The archive root is the directory that contains the project file
    (``Path(project.path).parent``); every entry is stored relative to it.
    Any path that has a component listed in ``skip_patterns`` *below* the
    archive root is left out.

    Returns the ZIP file as bytes.

    Raises:
        ValueError: if ``project.path`` is not set.
    """
    if not project.path:
        raise ValueError("Project path is not set")
    project_path = Path(project.path).parent

    # Skip common directories that shouldn't be included
    skip_patterns = {
        ".git",
        "__pycache__",
        ".pytest_cache",
        "node_modules",
        ".venv",
        "venv",
        ".DS_Store",
        ".vscode",
        ".idea",
    }

    buffer = io.BytesIO()
    file_count = 0
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for file_path in project_path.rglob("*"):
            arcname = file_path.relative_to(project_path)

            # Match skip patterns against the components inside the project
            # only. Checking the full path would also match ancestors of the
            # project root (e.g. a project stored under ".../venv/...") and
            # silently produce an empty archive.
            if not skip_patterns.isdisjoint(arcname.parts):
                continue

            if file_path.is_file():
                try:
                    zip_file.write(file_path, arcname=arcname)
                    file_count += 1
                except Exception as e:
                    # Best effort: an unreadable file is logged and skipped
                    # rather than aborting the whole archive.
                    logger.warning(f"Skipping file {file_path}: {e}")

    # getvalue() returns the full buffer regardless of the stream position.
    zip_bytes = buffer.getvalue()
    logger.info(
        f"Created project ZIP with {file_count} files, total size: {len(zip_bytes)} bytes"
    )
    return zip_bytes
|
|
||
|
|
||
| def connect_gepa_job_api(app: FastAPI): | ||
| @app.get("/api/projects/{project_id}/tasks/{task_id}/gepa_jobs/check_run_config") | ||
| async def check_run_config( | ||
|
|
@@ -515,7 +469,8 @@ async def start_gepa_job( | |
| Creates and saves a GepaJob datamodel to track the job. | ||
| """ | ||
| task = task_from_id(project_id, task_id) | ||
| if not task.parent: | ||
| project = task.parent_project() | ||
| if not project: | ||
| raise HTTPException(status_code=404, detail="Project not found") | ||
|
|
||
| try: | ||
|
|
@@ -538,49 +493,59 @@ async def start_gepa_job( | |
| status_code=500, detail="Server client not authenticated" | ||
| ) | ||
|
|
||
| # Create ZIP file of the project | ||
| project_zip_bytes = zip_project(cast(Project, task.parent)) | ||
| with tempfile.TemporaryDirectory(prefix="kiln_gepa_") as tmpdir: | ||
| tmp_file = Path(tmpdir) / "kiln_gepa_project.zip" | ||
| package_project_for_training( | ||
| project=project, | ||
| task_ids=[task_id], | ||
| run_config_id=request.target_run_config_id, | ||
| eval_ids=request.eval_ids, | ||
| output=tmp_file, | ||
| config=PackageForTrainingConfig( | ||
| include_documents=False, | ||
| exclude_task_runs=False, | ||
| exclude_eval_config_runs=True, | ||
| ), | ||
| ) | ||
| zip_bytes = tmp_file.read_bytes() | ||
leonardmq marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| logger.info( | ||
| f"Created project ZIP, total size: {len(zip_bytes)} bytes and file name: {tmp_file.name}" | ||
| ) | ||
|
|
||
| # Create the File object for the SDK | ||
| project_zip_file = File( | ||
| payload=io.BytesIO(project_zip_bytes), | ||
| file_name="project.zip", | ||
| mime_type="application/zip", | ||
| ) | ||
| project_zip_file = File( | ||
| payload=io.BytesIO(zip_bytes), | ||
| file_name="project.zip", | ||
| mime_type="application/zip", | ||
| ) | ||
|
|
||
| # Create the request body | ||
| body = BodyStartGepaJobV1JobsGepaJobStartPost( | ||
| token_budget=BodyStartGepaJobV1JobsGepaJobStartPostTokenBudget( | ||
| request.token_budget | ||
| ), | ||
| task_id=task_id, | ||
| target_run_config_id=request.target_run_config_id, | ||
| project_zip=project_zip_file, | ||
| eval_ids=request.eval_ids, | ||
| ) | ||
|
|
||
| response = await start_gepa_job_v1_jobs_gepa_job_start_post.asyncio( | ||
| client=server_client, body=body | ||
| ) | ||
|
|
||
| if isinstance(response, HTTPValidationError): | ||
| error_detail = ( | ||
| str(response.detail) | ||
| if hasattr(response, "detail") | ||
| else "Validation error" | ||
| detailed_response = ( | ||
| await start_gepa_job_v1_jobs_gepa_job_start_post.asyncio_detailed( | ||
| client=server_client, body=body | ||
| ) | ||
| raise HTTPException(status_code=422, detail=error_detail) | ||
| ) | ||
| check_response_error( | ||
| detailed_response, | ||
| default_detail="Failed to start GEPA job: unexpected error from server", | ||
| ) | ||
|
|
||
| if response is None: | ||
| response = detailed_response.parsed | ||
| if response is None or isinstance(response, HTTPValidationError): | ||
leonardmq marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| raise HTTPException( | ||
| status_code=500, | ||
| detail="Failed to start GEPA job: No response from server", | ||
| detail="Failed to start GEPA job: unexpected response from server", | ||
|
Comment on lines
+529
to
+543
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find check_response_error function definition
rg -n -A 30 'def check_response_error' --type py
Repository: Kiln-AI/Kiln
Length of output: 1720
HTTPValidationError isinstance check is unreachable for 422 responses
🤖 Prompt for AI Agents |
||
| ) | ||
leonardmq marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| gepa_job = GepaJob( | ||
| name=generate_memorable_name(), | ||
| job_id=response.job_id, | ||
| token_budget=request.token_budget, | ||
| target_run_config_id=request.target_run_config_id, | ||
| latest_status=JobStatus.PENDING, | ||
| eval_ids=request.eval_ids, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.