92 changes: 21 additions & 71 deletions app/desktop/studio_server/copilot_api.py
@@ -12,7 +12,6 @@
     ClarifySpecOutput,
     GenerateBatchInput,
     GenerateBatchOutput,
-    HTTPValidationError,
     RefineSpecInput,
 )
 from app.desktop.studio_server.api_client.kiln_ai_server_client.models import (
@@ -43,11 +42,11 @@
     TaskInfoApi,
 )
 from app.desktop.studio_server.utils.copilot_utils import (
-    check_response_error,
     create_dataset_task_runs,
     generate_copilot_examples,
     get_copilot_api_key,
 )
+from app.desktop.studio_server.utils.response_utils import unwrap_response
 from fastapi import FastAPI, HTTPException
 from kiln_ai.datamodel import TaskRun
 from kiln_ai.datamodel.basemodel import FilenameString
@@ -127,19 +126,10 @@ async def clarify_spec(input: ClarifySpecApiInput) -> ClarifySpecApiOutput:
             body=clarify_input,
         )
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500, detail="Failed to analyze spec. Please try again."
-        )
-
-    if isinstance(result, HTTPValidationError):
-        raise HTTPException(
-            status_code=422,
-            detail="Validation error.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to analyze spec. Please try again.",
+    )
 
     if isinstance(result, ClarifySpecOutput):
         return ClarifySpecApiOutput.model_validate(result.to_dict())
@@ -162,20 +152,10 @@ async def refine_spec(input: RefineSpecApiInput) -> RefineSpecApiOutput:
             body=refine_input,
         )
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to refine spec with feedback. Please try again.",
-        )
-
-    if isinstance(result, HTTPValidationError):
-        raise HTTPException(
-            status_code=422,
-            detail="Validation error.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to refine spec with feedback. Please try again.",
+    )
 
     if isinstance(result, RefineSpecApiOutputClient):
         return RefineSpecApiOutput.model_validate(result.to_dict())
@@ -198,20 +178,10 @@ async def generate_batch(input: GenerateBatchApiInput) -> GenerateBatchApiOutput
             body=generate_input,
         )
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to generate synthetic data for spec. Please try again.",
-        )
-
-    if isinstance(result, HTTPValidationError):
-        raise HTTPException(
-            status_code=422,
-            detail="Validation error.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to generate synthetic data for spec. Please try again.",
+    )
 
     if isinstance(result, GenerateBatchOutput):
         return GenerateBatchApiOutput.model_validate(result.to_dict())
@@ -236,20 +206,10 @@ async def question_spec(
             body=questioner_input,
         )
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to generate clarifying questions for spec. Please try again.",
-        )
-
-    if isinstance(result, HTTPValidationError):
-        raise HTTPException(
-            status_code=422,
-            detail="Validation error.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to generate clarifying questions for spec. Please try again.",
+    )
 
     if isinstance(result, QuestionSetServerApi):
         return QuestionSet.model_validate(result.to_dict())
@@ -272,20 +232,10 @@ async def submit_question_answers(
         client=client,
         body=submit_input,
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to refine spec with question answers. Please try again.",
-        )
-
-    if isinstance(result, HTTPValidationError):
-        raise HTTPException(
-            status_code=422,
-            detail="Validation error.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to refine spec with question answers. Please try again.",
+    )
 
     if isinstance(result, RefineSpecApiOutputClient):
         return RefineSpecApiOutput.model_validate(result.to_dict())
104 changes: 28 additions & 76 deletions app/desktop/studio_server/prompt_optimization_job_api.py
@@ -16,9 +16,6 @@
 from app.desktop.studio_server.api_client.kiln_ai_server_client.models.body_start_prompt_optimization_job_v1_jobs_prompt_optimization_job_start_post import (
     BodyStartPromptOptimizationJobV1JobsPromptOptimizationJobStartPost,
 )
-from app.desktop.studio_server.api_client.kiln_ai_server_client.models.http_validation_error import (
-    HTTPValidationError,
-)
 from app.desktop.studio_server.api_client.kiln_ai_server_client.models.job_status import (
     JobStatus,
 )
@@ -34,7 +31,7 @@
     eval_from_id,
     task_run_config_from_id,
 )
-from app.desktop.studio_server.utils.copilot_utils import check_response_error
+from app.desktop.studio_server.utils.response_utils import unwrap_response
 from fastapi import FastAPI, HTTPException
 from kiln_ai.cli.commands.package_project import (
     PackageForTrainingConfig,
@@ -240,17 +237,16 @@ async def _create_artifacts_for_succeeded_job(
         prompt_optimization_job.optimized_prompt = reloaded_job.optimized_prompt
         return
 
-    result_response = await get_prompt_optimization_job_result_v1_jobs_prompt_optimization_job_job_id_result_get.asyncio(
+    detailed_response = await get_prompt_optimization_job_result_v1_jobs_prompt_optimization_job_job_id_result_get.asyncio_detailed(
         job_id=prompt_optimization_job.job_id,
         client=server_client,
     )
+    result_response = unwrap_response(
+        detailed_response,
+        default_detail="Failed to get Prompt Optimization job result.",
+    )
 
-    if (
-        result_response
-        and not isinstance(result_response, HTTPValidationError)
-        and result_response.output
-        and hasattr(result_response.output, "optimized_prompt")
-    ):
+    if result_response.output and result_response.output.optimized_prompt:
         optimized_prompt_text = result_response.output.optimized_prompt
         prompt_optimization_job.optimized_prompt = optimized_prompt_text
 
@@ -295,21 +291,17 @@ async def update_prompt_optimization_job_and_create_artifacts(
     )
 
     try:
-        status_response = (
-            await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio(
+        detailed_response = (
+            await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio_detailed(
                 job_type=JobType.GEPA_JOB,
                 job_id=prompt_optimization_job.job_id,
                 client=server_client,
             )
         )
-
-        if status_response is None or isinstance(status_response, HTTPValidationError):
-            logger.warning(
-                f"Could not fetch status for Prompt Optimization job {prompt_optimization_job.job_id}"
-            )
-            raise RuntimeError(
-                f"Could not fetch status for Prompt Optimization job {prompt_optimization_job.job_id}: {status_response}"
-            )
+        status_response = unwrap_response(
+            detailed_response,
+            default_detail=f"Could not fetch status for Prompt Optimization job {prompt_optimization_job.job_id}",
+        )
 
         new_status = str(status_response.status.value)
 
@@ -370,25 +362,12 @@ async def check_run_config(
             status_code=500, detail="Server client not authenticated"
         )
 
-    response = await check_prompt_optimization_model_supported_v1_jobs_prompt_optimization_job_check_model_supported_get.asyncio(
+    detailed_response = await check_prompt_optimization_model_supported_v1_jobs_prompt_optimization_job_check_model_supported_get.asyncio_detailed(
         client=server_client,
         model_name=model_name,
         model_provider_name=model_provider.value,
     )
-
-    if isinstance(response, HTTPValidationError):
-        error_detail = (
-            str(response.detail)
-            if hasattr(response, "detail")
-            else "Validation error"
-        )
-        raise HTTPException(status_code=422, detail=error_detail)
-
-    if response is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to check run config: No response from server",
-        )
+    response = unwrap_response(detailed_response)
 
     return CheckRunConfigResponse(is_supported=response.is_model_supported)

@@ -451,25 +430,12 @@ async def check_eval(
         )
 
     # EvalConfig.model_provider is already a string, no need for .value
-    response = await check_prompt_optimization_model_supported_v1_jobs_prompt_optimization_job_check_model_supported_get.asyncio(
+    detailed_response = await check_prompt_optimization_model_supported_v1_jobs_prompt_optimization_job_check_model_supported_get.asyncio_detailed(
         client=server_client,
         model_name=model_name,
        model_provider_name=model_provider,
     )
-
-    if isinstance(response, HTTPValidationError):
-        error_detail = (
-            str(response.detail)
-            if hasattr(response, "detail")
-            else "Validation error"
-        )
-        raise HTTPException(status_code=422, detail=error_detail)
-
-    if response is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to check eval: No response from server",
-        )
+    response = unwrap_response(detailed_response)
 
     return CheckEvalResponse(
         has_default_config=True,
@@ -560,17 +526,7 @@ async def start_prompt_optimization_job(
     detailed_response = await start_prompt_optimization_job_v1_jobs_prompt_optimization_job_start_post.asyncio_detailed(
         client=server_client, body=body
     )
-    check_response_error(
-        detailed_response,
-        default_detail="Failed to start Prompt Optimization job: unexpected error from server",
-    )
-
-    response = detailed_response.parsed
-    if response is None or isinstance(response, HTTPValidationError):
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to start Prompt Optimization job: unexpected response from server",
-        )
+    response = unwrap_response(detailed_response)
 
     prompt_optimization_job = PromptOptimizationJob(
         name=generate_memorable_name(),
@@ -695,17 +651,15 @@ async def get_prompt_optimization_job_status(
             status_code=500, detail="Server client not authenticated"
         )
 
-    response = await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio(
+    detailed_response = await get_job_status_v1_jobs_job_type_job_id_status_get.asyncio_detailed(
         job_type=JobType.GEPA_JOB,
         job_id=job_id,
         client=server_client,
     )
-
-    if response is None or isinstance(response, HTTPValidationError):
-        raise HTTPException(
-            status_code=404,
-            detail=f"Prompt Optimization job {job_id} not found",
-        )
+    response = unwrap_response(
+        detailed_response,
+        default_detail=f"Prompt Optimization job {job_id} not found",
+    )
 
     return PublicPromptOptimizationJobStatusResponse(
         job_id=response.job_id, status=response.status
@@ -736,16 +690,14 @@ async def get_prompt_optimization_job_result(
             status_code=500, detail="Server client not authenticated"
         )
 
-    response = await get_prompt_optimization_job_result_v1_jobs_prompt_optimization_job_job_id_result_get.asyncio(
+    detailed_response = await get_prompt_optimization_job_result_v1_jobs_prompt_optimization_job_job_id_result_get.asyncio_detailed(
         job_id=job_id,
         client=server_client,
     )
-
-    if response is None or isinstance(response, HTTPValidationError):
-        raise HTTPException(
-            status_code=404,
-            detail=f"Prompt Optimization job {job_id} result not found",
-        )
+    response = unwrap_response(
+        detailed_response,
+        default_detail=f"Prompt Optimization job {job_id} result not found",
+    )
 
     if not response.output or not hasattr(response.output, "optimized_prompt"):
         raise HTTPException(
18 changes: 6 additions & 12 deletions app/desktop/studio_server/settings_api.py
@@ -7,10 +7,8 @@
 from app.desktop.studio_server.api_client.kiln_server_client import (
     get_authenticated_client,
 )
-from app.desktop.studio_server.utils.copilot_utils import (
-    check_response_error,
-    get_copilot_api_key,
-)
+from app.desktop.studio_server.utils.copilot_utils import get_copilot_api_key
+from app.desktop.studio_server.utils.response_utils import unwrap_response
 from fastapi import FastAPI, HTTPException
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.filesystem import open_folder
@@ -79,13 +77,9 @@ async def check_entitlements(feature_codes: str) -> dict[str, bool]:
             feature_codes=feature_codes,
         )
     )
-    check_response_error(detailed_result)
-
-    result = detailed_result.parsed
-    if result is None:
-        raise HTTPException(
-            status_code=500,
-            detail="Failed to check entitlements. Please try again.",
-        )
+    result = unwrap_response(
+        detailed_result,
+        none_detail="Failed to check entitlements. Please try again.",
+    )
 
     return result.additional_properties
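
Note: this PR moves the repeated error handling into app/desktop/studio_server/utils/response_utils.py, but that new file is not shown above. What follows is a minimal sketch of what unwrap_response plausibly does, inferred only from the call sites in this diff (the none_detail and default_detail keywords, and the 500/422 responses the deleted inline checks produced); the real implementation may differ.

# Hypothetical reconstruction of unwrap_response; not the actual code from this PR.
from typing import Any

from fastapi import HTTPException

from app.desktop.studio_server.api_client.kiln_ai_server_client.models import (
    HTTPValidationError,
)


def unwrap_response(
    detailed_response: Any,  # a generated-client Response with .status_code and .parsed
    *,
    none_detail: str | None = None,
    default_detail: str | None = None,
) -> Any:
    """Unwrap a *_detailed response, raising HTTPException on any failure."""
    detail = default_detail or none_detail or "Unexpected response from server."

    # Surface non-2xx responses as HTTP errors, preserving the upstream status.
    if not (200 <= int(detailed_response.status_code) < 300):
        raise HTTPException(
            status_code=int(detailed_response.status_code), detail=detail
        )

    result = detailed_response.parsed
    if result is None:
        # 2xx with nothing parseable in the body; the deleted checks raised 500 here.
        raise HTTPException(status_code=500, detail=detail)

    if isinstance(result, HTTPValidationError):
        # The deleted checks mapped validation errors to 422.
        raise HTTPException(status_code=422, detail="Validation error.")

    return result

Under these assumptions, every call site collapses to a single unwrap_response call plus an optional human-readable message, and the per-endpoint HTTPValidationError imports disappear, which matches the shape of the diff above.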