Skip to content

Commit 10e321a

Browse files
committed
Merge branch 'sfierro/specs_server_apis_use' into sfierro/specs_server_apis_python_sdk
2 parents 6727774 + 5869891 commit 10e321a

File tree

3 files changed

+160
-74
lines changed

3 files changed

+160
-74
lines changed

app/desktop/studio_server/copilot_api.py

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Any
2-
31
from app.desktop.studio_server.api_client.kiln_ai_server_client.api.copilot import (
42
clarify_spec_v1_copilot_clarify_spec_post,
53
generate_batch_v1_copilot_generate_batch_post,
@@ -8,23 +6,43 @@
86
from app.desktop.studio_server.api_client.kiln_ai_server_client.models import (
97
ClarifySpecInput,
108
ClarifySpecOutput,
11-
ExampleWithFeedback,
129
GenerateBatchInput,
1310
GenerateBatchOutput,
1411
HTTPValidationError,
1512
RefineSpecInput,
1613
RefineSpecOutput,
17-
SpecInfo,
18-
TaskInfo,
1914
)
2015
from app.desktop.studio_server.api_client.kiln_server_client import (
2116
get_authenticated_client,
2217
)
2318
from fastapi import FastAPI, HTTPException
19+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
2420
from kiln_ai.utils.config import Config
2521
from pydantic import BaseModel, ConfigDict, Field
2622

2723

24+
# Pydantic input models (replacing attrs-based client models)
25+
class TaskInfoApi(BaseModel):
    # Pydantic model used as the type of RefineSpecApiInput.task_info.

    # Required prompt text.
    task_prompt: str
    # Optional; defaults to None when the caller omits it.
    few_shot_examples: str | None = Field(default=None)
28+
29+
30+
class SpecInfoApi(BaseModel):
    # Pydantic model used as the type of RefineSpecApiInput.spec.
    # Both attributes are required string-to-string mappings.

    spec_fields: dict[str, str]
    spec_field_current_values: dict[str, str]
33+
34+
35+
class ExampleWithFeedbackApi(BaseModel):
    # Pydantic model for one entry of RefineSpecApiInput.examples_with_feedback.

    # Configured through `model_config` (Pydantic v2 style). The previous nested
    # `class Config` is the deprecated v1 mechanism and also reused the name of
    # the application `Config` object imported at module level.
    model_config = ConfigDict(populate_by_name=True)

    user_rating_exhibits_issue_correct: bool
    # The alias is identical to the field name, so "input" is the key used for
    # both validation and by-alias serialization.
    input: str = Field(alias="input")
    output: str
    exhibits_issue: bool
    # Optional free-text feedback; defaults to None when omitted.
    user_feedback: str | None = None
44+
45+
2846
class ClarifySpecApiInput(BaseModel):
2947
task_prompt_with_few_shot: str
3048
task_input_schema: str
@@ -39,9 +57,9 @@ class RefineSpecApiInput(BaseModel):
3957
task_prompt_with_few_shot: str
4058
task_input_schema: str
4159
task_output_schema: str
42-
task_info: TaskInfo
43-
spec: SpecInfo
44-
examples_with_feedback: list[ExampleWithFeedback]
60+
task_info: TaskInfoApi
61+
spec: SpecInfoApi
62+
examples_with_feedback: list[ExampleWithFeedbackApi]
4563

4664

4765
class GenerateBatchApiInput(BaseModel):
@@ -54,6 +72,49 @@ class GenerateBatchApiInput(BaseModel):
5472
enable_scoring: bool = Field(default=False)
5573

5674

75+
class SubsampleBatchOutputItemApi(BaseModel):
    # Element type of ClarifySpecApiOutput.examples_for_feedback.
    # All three fields are required.

    # Alias equals the field name, so the wire key is "input".
    input: str = Field(alias="input")
    output: str
    exhibits_issue: bool
79+
80+
81+
class ClarifySpecApiOutput(BaseModel):
    # Response model of POST /api/copilot/clarify_spec; built there by
    # validating the dict form of the client's ClarifySpecOutput.

    examples_for_feedback: list[SubsampleBatchOutputItemApi]
    model_id: str
    # Validated against the ModelProviderName enum from kiln_ai.
    model_provider: ModelProviderName
    judge_prompt: str
86+
87+
88+
class SpecEditApi(BaseModel):
    # Value type of RefineSpecApiOutput.new_proposed_spec_edits.
    # All three fields are required strings.

    old_value: str
    proposed_edit: str
    reason_for_edit: str
92+
93+
94+
class RefineSpecApiOutput(BaseModel):
    # Response model of POST /api/copilot/refine_spec; built there by
    # validating the dict form of the client's RefineSpecOutput.

    # Mapping from a string key to the edit proposed for it.
    new_proposed_spec_edits: dict[str, SpecEditApi]
    out_of_scope_feedback: str
97+
98+
99+
class SampleApi(BaseModel):
    # Unscored member of the union stored in GenerateBatchApiOutput.data_by_topic.

    # Alias equals the field name, so the wire key is "input".
    input: str = Field(alias="input")
    output: str
102+
103+
104+
class ScoredSampleApi(BaseModel):
    # Scored member of the union stored in GenerateBatchApiOutput.data_by_topic:
    # the same `input`/`output` pair as SampleApi plus two required fields.

    # Alias equals the field name, so the wire key is "input".
    input: str = Field(alias="input")
    output: str
    exhibits_issue: bool
    reasoning: str
109+
110+
111+
class GenerateBatchApiOutput(BaseModel):
    # Response model of POST /api/copilot/generate_batch; built there by
    # validating the dict form of the client's GenerateBatchOutput.

    # ScoredSampleApi is deliberately listed before SampleApi. Every scored
    # payload also validates as SampleApi (unknown keys are ignored by default),
    # so trying the more specific model first guarantees `exhibits_issue` and
    # `reasoning` are kept regardless of how the installed Pydantic version
    # resolves unions. Unscored payloads fail ScoredSampleApi (missing required
    # fields) and fall through to SampleApi.
    data_by_topic: dict[str, list[ScoredSampleApi | SampleApi]]
    # The three prompts are optional and default to None.
    topic_gen_prompt: str | None = None
    input_gen_prompt: str | None = None
    judge_prompt: str | None = None
116+
117+
57118
def _get_api_key() -> str:
58119
"""Get the Kiln Copilot API key from config, raising an error if not set."""
59120
api_key = Config.shared().kiln_copilot_api_key
@@ -67,7 +128,7 @@ def _get_api_key() -> str:
67128

68129
def connect_copilot_api(app: FastAPI):
69130
@app.post("/api/copilot/clarify_spec")
70-
async def clarify_spec(input: ClarifySpecApiInput) -> dict[str, Any]:
131+
async def clarify_spec(input: ClarifySpecApiInput) -> ClarifySpecApiOutput:
71132
api_key = _get_api_key()
72133
client = get_authenticated_client(api_key)
73134

@@ -90,19 +151,19 @@ async def clarify_spec(input: ClarifySpecApiInput) -> dict[str, Any]:
90151
)
91152

92153
if isinstance(result, ClarifySpecOutput):
93-
return result.to_dict()
154+
return ClarifySpecApiOutput.model_validate(result.to_dict())
94155

95156
raise HTTPException(
96157
status_code=500,
97158
detail=f"Failed to clarify spec: Unexpected response type {type(result)}",
98159
)
99160

100161
@app.post("/api/copilot/refine_spec")
101-
async def refine_spec(input: RefineSpecApiInput) -> dict[str, Any]:
162+
async def refine_spec(input: RefineSpecApiInput) -> RefineSpecApiOutput:
102163
api_key = _get_api_key()
103164
client = get_authenticated_client(api_key)
104165

105-
refine_input = RefineSpecInput(**input.model_dump())
166+
refine_input = RefineSpecInput.from_dict(input.model_dump(by_alias=True))
106167

107168
result = await refine_spec_v1_copilot_refine_spec_post.asyncio(
108169
client=client,
@@ -121,15 +182,15 @@ async def refine_spec(input: RefineSpecApiInput) -> dict[str, Any]:
121182
)
122183

123184
if isinstance(result, RefineSpecOutput):
124-
return result.to_dict()
185+
return RefineSpecApiOutput.model_validate(result.to_dict())
125186

126187
raise HTTPException(
127188
status_code=500,
128189
detail=f"Failed to refine spec: Unexpected response type {type(result)}",
129190
)
130191

131192
@app.post("/api/copilot/generate_batch")
132-
async def generate_batch(input: GenerateBatchApiInput) -> dict[str, Any]:
193+
async def generate_batch(input: GenerateBatchApiInput) -> GenerateBatchApiOutput:
133194
api_key = _get_api_key()
134195
client = get_authenticated_client(api_key)
135196

@@ -152,7 +213,7 @@ async def generate_batch(input: GenerateBatchApiInput) -> dict[str, Any]:
152213
)
153214

154215
if isinstance(result, GenerateBatchOutput):
155-
return result.to_dict()
216+
return GenerateBatchApiOutput.model_validate(result.to_dict())
156217

157218
raise HTTPException(
158219
status_code=500,

app/desktop/studio_server/test_copilot_api.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from unittest.mock import AsyncMock, patch
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
4-
from app.desktop.studio_server.api_client.kiln_ai_server_client.models import (
4+
from app.desktop.studio_server.api_client.kiln_ai_server_client.models.clarify_spec_output import (
55
ClarifySpecOutput,
6+
)
7+
from app.desktop.studio_server.api_client.kiln_ai_server_client.models.generate_batch_output import (
68
GenerateBatchOutput,
7-
GenerateBatchOutputDataByTopic,
9+
)
10+
from app.desktop.studio_server.api_client.kiln_ai_server_client.models.http_validation_error import (
811
HTTPValidationError,
9-
ModelProviderName,
12+
)
13+
from app.desktop.studio_server.api_client.kiln_ai_server_client.models.refine_spec_output import (
1014
RefineSpecOutput,
11-
RefineSpecOutputNewProposedSpecEdits,
12-
SubsampleBatchOutputItem,
1315
)
1416
from app.desktop.studio_server.copilot_api import connect_copilot_api
1517
from fastapi import FastAPI
@@ -101,18 +103,19 @@ def test_clarify_spec_no_api_key(self, client, clarify_spec_input):
101103
assert "API key not configured" in response.json()["detail"]
102104

103105
def test_clarify_spec_success(self, client, clarify_spec_input, mock_api_key):
104-
mock_output = ClarifySpecOutput(
105-
examples_for_feedback=[
106-
SubsampleBatchOutputItem(
107-
input_="test input",
108-
output="test output",
109-
exhibits_issue=False,
110-
)
106+
mock_output = MagicMock(spec=ClarifySpecOutput)
107+
mock_output.to_dict.return_value = {
108+
"examples_for_feedback": [
109+
{
110+
"input": "test input",
111+
"output": "test output",
112+
"exhibits_issue": False,
113+
}
111114
],
112-
model_id="gpt-4",
113-
model_provider=ModelProviderName.OPENAI,
114-
judge_prompt="Test judge prompt",
115-
)
115+
"model_id": "gpt-4",
116+
"model_provider": "openai",
117+
"judge_prompt": "Test judge prompt",
118+
}
116119

117120
with patch(
118121
"app.desktop.studio_server.copilot_api.clarify_spec_v1_copilot_clarify_spec_post.asyncio",
@@ -138,7 +141,8 @@ def test_clarify_spec_no_response(self, client, clarify_spec_input, mock_api_key
138141
def test_clarify_spec_validation_error(
139142
self, client, clarify_spec_input, mock_api_key
140143
):
141-
mock_error = HTTPValidationError(detail=[])
144+
mock_error = MagicMock(spec=HTTPValidationError)
145+
mock_error.to_dict.return_value = {"detail": []}
142146

143147
with patch(
144148
"app.desktop.studio_server.copilot_api.clarify_spec_v1_copilot_clarify_spec_post.asyncio",
@@ -163,10 +167,11 @@ def test_refine_spec_no_api_key(self, client, refine_spec_input):
163167
assert "API key not configured" in response.json()["detail"]
164168

165169
def test_refine_spec_success(self, client, refine_spec_input, mock_api_key):
166-
mock_output = RefineSpecOutput(
167-
new_proposed_spec_edits=RefineSpecOutputNewProposedSpecEdits(),
168-
out_of_scope_feedback="No out of scope feedback",
169-
)
170+
mock_output = MagicMock(spec=RefineSpecOutput)
171+
mock_output.to_dict.return_value = {
172+
"new_proposed_spec_edits": {},
173+
"out_of_scope_feedback": "No out of scope feedback",
174+
}
170175

171176
with patch(
172177
"app.desktop.studio_server.copilot_api.refine_spec_v1_copilot_refine_spec_post.asyncio",
@@ -192,7 +197,8 @@ def test_refine_spec_no_response(self, client, refine_spec_input, mock_api_key):
192197
def test_refine_spec_validation_error(
193198
self, client, refine_spec_input, mock_api_key
194199
):
195-
mock_error = HTTPValidationError(detail=[])
200+
mock_error = MagicMock(spec=HTTPValidationError)
201+
mock_error.to_dict.return_value = {"detail": []}
196202

197203
with patch(
198204
"app.desktop.studio_server.copilot_api.refine_spec_v1_copilot_refine_spec_post.asyncio",
@@ -219,9 +225,8 @@ def test_generate_batch_no_api_key(self, client, generate_batch_input):
219225
assert "API key not configured" in response.json()["detail"]
220226

221227
def test_generate_batch_success(self, client, generate_batch_input, mock_api_key):
222-
mock_output = GenerateBatchOutput(
223-
data_by_topic=GenerateBatchOutputDataByTopic(),
224-
)
228+
mock_output = MagicMock(spec=GenerateBatchOutput)
229+
mock_output.to_dict.return_value = {"data_by_topic": {}}
225230

226231
with patch(
227232
"app.desktop.studio_server.copilot_api.generate_batch_v1_copilot_generate_batch_post.asyncio",
@@ -252,7 +257,8 @@ def test_generate_batch_no_response(
252257
def test_generate_batch_validation_error(
253258
self, client, generate_batch_input, mock_api_key
254259
):
255-
mock_error = HTTPValidationError(detail=[])
260+
mock_error = MagicMock(spec=HTTPValidationError)
261+
mock_error.to_dict.return_value = {"detail": []}
256262

257263
with patch(
258264
"app.desktop.studio_server.copilot_api.generate_batch_v1_copilot_generate_batch_post.asyncio",
@@ -269,9 +275,8 @@ def test_generate_batch_with_scoring(
269275
self, client, generate_batch_input, mock_api_key
270276
):
271277
generate_batch_input["enable_scoring"] = True
272-
mock_output = GenerateBatchOutput(
273-
data_by_topic=GenerateBatchOutputDataByTopic(),
274-
)
278+
mock_output = MagicMock(spec=GenerateBatchOutput)
279+
mock_output.to_dict.return_value = {"data_by_topic": {}}
275280

276281
with patch(
277282
"app.desktop.studio_server.copilot_api.generate_batch_v1_copilot_generate_batch_post.asyncio",

app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/review_spec/+page.svelte

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import CheckCircleIcon from "$lib/ui/icons/check_circle_icon.svelte"
2020
import ExclaimCircleIcon from "$lib/ui/icons/exclaim_circle_icon.svelte"
2121
import SpecAnalyzingAnimation from "../spec_analyzing_animation.svelte"
22+
import { client } from "$lib/api_client"
23+
import { load_task } from "$lib/stores"
24+
import { buildDefinitionFromProperties } from "../select_template/spec_templates"
2225
2326
$: project_id = $page.params.project_id
2427
$: task_id = $page.params.task_id
@@ -65,9 +68,6 @@
6568
}
6669
6770
onMount(async () => {
68-
// Wait 3 seconds to simulate loading time
69-
await new Promise((resolve) => setTimeout(resolve, 3000))
70-
7171
await load_spec_data()
7272
})
7373
@@ -84,33 +84,53 @@
8484
property_values = { ...formData.property_values }
8585
evaluate_full_trace = formData.evaluate_full_trace
8686
87-
// Generate mock review data (in a real implementation, this would come from an API)
88-
review_rows = [
89-
{
90-
id: "1",
91-
input: "User uploads a PDF document",
92-
output: "Document successfully processed",
93-
model_decision: "meets_spec",
94-
meets_spec: null,
95-
feedback: "",
96-
},
97-
{
98-
id: "2",
99-
input: "User tries to upload an invalid file",
100-
output: "Error: Invalid file format",
101-
model_decision: "fails_spec",
102-
meets_spec: null,
103-
feedback: "",
104-
},
105-
{
106-
id: "3",
107-
input: "User requests a summary of uploaded file",
108-
output: "Summary generated successfully",
109-
model_decision: "meets_spec",
110-
meets_spec: null,
111-
feedback: "",
87+
// Load the task to get instruction and schemas
88+
const task = await load_task(project_id, task_id)
89+
if (!task) {
90+
throw new Error("Failed to load task")
91+
}
92+
93+
const spec_definition = buildDefinitionFromProperties(
94+
spec_type,
95+
property_values,
96+
)
97+
98+
// TODO: Create a few shot prompt instead of basic prompt
99+
// TODO: What should task input/output schemas be exactly? Especially for plaintext tasks?
100+
const { data, error } = await client.POST("/api/copilot/clarify_spec", {
101+
body: {
102+
task_prompt_with_few_shot: task.instruction,
103+
task_input_schema: task.input_json_schema
104+
? JSON.stringify(task.input_json_schema)
105+
: "",
106+
task_output_schema: task.output_json_schema
107+
? JSON.stringify(task.output_json_schema)
108+
: "",
109+
spec_rendered_prompt_template: spec_definition,
110+
num_samples_per_topic: 10,
111+
num_topics: 5,
112+
num_exemplars: 10,
112113
},
113-
]
114+
})
115+
116+
if (error) {
117+
throw error
118+
}
119+
120+
if (!data) {
121+
throw new Error(
122+
"Failed to analyze spec for review. Please try again.",
123+
)
124+
}
125+
126+
review_rows = data.examples_for_feedback.map((example, index) => ({
127+
id: String(index + 1),
128+
input: example.input,
129+
output: example.output,
130+
model_decision: example.exhibits_issue ? "fails_spec" : "meets_spec",
131+
meets_spec: null,
132+
feedback: "",
133+
}))
114134
115135
// Don't clear the stored data - keep it for back navigation
116136
// It will be cleared when the spec is successfully created

0 commit comments

Comments
 (0)