Skip to content

Commit 0fec5e1

Browse files
committed
copilot_api
1 parent acfcf91 commit 0fec5e1

File tree

5 files changed

+521
-8
lines changed

5 files changed

+521
-8
lines changed

app/desktop/desktop_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from kiln_ai.utils.logging import setup_litellm_logging
1515

1616
from app.desktop.log_config import log_config
17+
from app.desktop.studio_server.copilot_api import connect_copilot_api
1718
from app.desktop.studio_server.data_gen_api import connect_data_gen_api
1819
from app.desktop.studio_server.dev_tools import connect_dev_tools
1920
from app.desktop.studio_server.eval_api import connect_evals_api
@@ -64,6 +65,7 @@ def make_app(tk_root: tk.Tk | None = None):
6465
connect_evals_api(app)
6566
connect_import_api(app, tk_root=tk_root)
6667
connect_tool_servers_api(app)
68+
connect_copilot_api(app)
6769
connect_dev_tools(app)
6870

6971
# Important: webhost must be last, it handles all other URLs

app/desktop/studio_server/api_client/kiln_server_client.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import os
22
from importlib.metadata import version
33

4+
import httpx
5+
from app.desktop.studio_server.api_client.kiln_ai_server_client.client import (
6+
AuthenticatedClient,
7+
)
48
from app.desktop.studio_server.api_client.kiln_ai_server_client.client import (
59
Client as KilnServerClient,
610
)
@@ -14,15 +18,35 @@ def _get_desktop_app_version() -> str:
1418
return "unknown"
1519

1620

17-
def get_kiln_server_client() -> KilnServerClient:
21+
def _get_base_url() -> str:
22+
"""Get the base URL for the Kiln server."""
23+
return os.getenv("KILN_SERVER_BASE_URL", "https://api.kiln.tech")
24+
25+
26+
def _get_common_headers() -> dict[str, str]:
27+
"""Get common headers for all Kiln server requests."""
1828
app_version = _get_desktop_app_version()
19-
base_url = os.getenv("KILN_SERVER_BASE_URL", "https://api.kiln.tech")
29+
return {
30+
"User-Agent": f"KilnDesktopApp/{app_version}",
31+
"Kiln-Desktop-App-Version": app_version,
32+
}
33+
34+
35+
def get_kiln_server_client() -> KilnServerClient:
36+
"""Get a non-authenticated client for the Kiln server."""
2037
return KilnServerClient(
21-
base_url=base_url,
22-
headers={
23-
"User-Agent": f"KilnDesktopApp/{app_version}",
24-
"Kiln-Desktop-App-Version": app_version,
25-
},
38+
base_url=_get_base_url(),
39+
headers=_get_common_headers(),
40+
)
41+
42+
43+
def get_authenticated_client(api_key: str) -> AuthenticatedClient:
44+
"""Get an authenticated client for the Kiln server with the provided API key."""
45+
return AuthenticatedClient(
46+
base_url=_get_base_url(),
47+
token=api_key,
48+
headers=_get_common_headers(),
49+
timeout=httpx.Timeout(timeout=300.0, connect=30.0),
2650
)
2751

2852

app/desktop/studio_server/api_client/test_kiln_server_client.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44

55
import httpx
66
import pytest
7-
87
from app.desktop.studio_server.api_client.kiln_ai_server_client.api.health import (
98
ping_ping_get,
109
)
10+
from app.desktop.studio_server.api_client.kiln_ai_server_client.client import (
11+
AuthenticatedClient,
12+
)
1113
from app.desktop.studio_server.api_client.kiln_server_client import (
1214
KilnServerClient,
1315
_get_desktop_app_version,
16+
get_authenticated_client,
1417
get_kiln_server_client,
1518
server_client,
1619
)
@@ -162,3 +165,44 @@ async def test_async_ping_request(self, mock_async_transport):
162165
result = await ping_ping_get.asyncio(client=client)
163166

164167
assert result == "pong"
168+
169+
170+
class TestGetAuthenticatedClient:
171+
"""Tests for the get_authenticated_client factory function."""
172+
173+
def test_returns_authenticated_client(self):
174+
"""Verify the function returns an AuthenticatedClient instance."""
175+
client = get_authenticated_client("test_api_key")
176+
assert isinstance(client, AuthenticatedClient)
177+
178+
def test_returns_client_with_correct_base_url_no_env_var(self):
179+
"""Verify the client is configured with the correct base URL."""
180+
with patch.dict(os.environ, {}, clear=True):
181+
client = get_authenticated_client("test_api_key")
182+
assert client._base_url == "https://api.kiln.tech"
183+
184+
def test_returns_client_with_correct_base_url_with_env_var(self):
185+
"""Verify the client is configured with the correct base URL from env."""
186+
with patch.dict(
187+
os.environ, {"KILN_SERVER_BASE_URL": "https://localhost:8000"}, clear=True
188+
):
189+
client = get_authenticated_client("test_api_key")
190+
assert client._base_url == "https://localhost:8000"
191+
192+
def test_returns_client_with_correct_token(self):
193+
"""Verify the client has the correct token set."""
194+
client = get_authenticated_client("my_secret_token")
195+
assert client.token == "my_secret_token"
196+
197+
def test_returns_client_with_correct_user_agent_header(self):
198+
"""Verify the client has the correct User-Agent header set."""
199+
client = get_authenticated_client("test_api_key")
200+
assert client._headers["User-Agent"] == f"KilnDesktopApp/{APP_VERSION}"
201+
assert client._headers["Kiln-Desktop-App-Version"] == APP_VERSION
202+
203+
def test_client_has_appropriate_timeout(self):
204+
"""Verify the client has a reasonable timeout for long-running requests."""
205+
client = get_authenticated_client("test_api_key")
206+
assert client._timeout is not None
207+
assert client._timeout.read == 300.0
208+
assert client._timeout.connect == 30.0
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from typing import Any
2+
3+
from app.desktop.studio_server.api_client.kiln_ai_server_client.api.copilot import (
4+
clarify_spec_v1_copilot_clarify_spec_post,
5+
generate_batch_v1_copilot_generate_batch_post,
6+
refine_spec_v1_copilot_refine_spec_post,
7+
)
8+
from app.desktop.studio_server.api_client.kiln_ai_server_client.models import (
9+
ClarifySpecInput,
10+
ClarifySpecOutput,
11+
ExampleWithFeedback,
12+
GenerateBatchInput,
13+
GenerateBatchOutput,
14+
HTTPValidationError,
15+
RefineSpecInput,
16+
RefineSpecOutput,
17+
SpecInfo,
18+
TaskInfo,
19+
)
20+
from app.desktop.studio_server.api_client.kiln_server_client import (
21+
get_authenticated_client,
22+
)
23+
from fastapi import FastAPI, HTTPException
24+
from kiln_ai.utils.config import Config
25+
from pydantic import BaseModel, Field
26+
27+
28+
class ClarifySpecApiInput(BaseModel):
29+
task_prompt_with_few_shot: str
30+
task_input_schema: str
31+
task_output_schema: str
32+
spec_rendered_prompt_template: str
33+
num_samples_per_topic: int
34+
num_topics: int
35+
num_exemplars: int = Field(default=10)
36+
37+
38+
class RefineSpecApiInput(BaseModel):
39+
task_prompt_with_few_shot: str
40+
task_input_schema: str
41+
task_output_schema: str
42+
task_info: dict[str, Any]
43+
spec: dict[str, Any]
44+
examples_with_feedback: list[dict[str, Any]]
45+
46+
47+
class GenerateBatchApiInput(BaseModel):
48+
task_prompt_with_few_shot: str
49+
task_input_schema: str
50+
task_output_schema: str
51+
spec_rendered_prompt_template: str
52+
num_samples_per_topic: int
53+
num_topics: int
54+
enable_scoring: bool = Field(default=False)
55+
56+
57+
def _get_api_key() -> str:
58+
"""Get the Kiln Copilot API key from config, raising an error if not set."""
59+
api_key = Config.shared().kiln_copilot_api_key
60+
if not api_key:
61+
raise HTTPException(
62+
status_code=401,
63+
detail="Kiln Copilot API key not configured. Please connect your API key in settings.",
64+
)
65+
return api_key
66+
67+
68+
def connect_copilot_api(app: FastAPI):
69+
@app.post("/api/copilot/clarify_spec")
70+
async def clarify_spec(input: ClarifySpecApiInput) -> dict[str, Any]:
71+
api_key = _get_api_key()
72+
client = get_authenticated_client(api_key)
73+
74+
clarify_input = ClarifySpecInput(
75+
task_prompt_with_few_shot=input.task_prompt_with_few_shot,
76+
task_input_schema=input.task_input_schema,
77+
task_output_schema=input.task_output_schema,
78+
spec_rendered_prompt_template=input.spec_rendered_prompt_template,
79+
num_samples_per_topic=input.num_samples_per_topic,
80+
num_topics=input.num_topics,
81+
num_exemplars=input.num_exemplars,
82+
)
83+
84+
result = await clarify_spec_v1_copilot_clarify_spec_post.asyncio(
85+
client=client,
86+
body=clarify_input,
87+
)
88+
89+
if result is None:
90+
raise HTTPException(
91+
status_code=500, detail="Failed to clarify spec: No response"
92+
)
93+
94+
if isinstance(result, HTTPValidationError):
95+
raise HTTPException(
96+
status_code=422,
97+
detail=f"Validation error: {result.to_dict()}",
98+
)
99+
100+
if isinstance(result, ClarifySpecOutput):
101+
return result.to_dict()
102+
103+
raise HTTPException(
104+
status_code=500,
105+
detail=f"Failed to clarify spec: Unexpected response type {type(result)}",
106+
)
107+
108+
@app.post("/api/copilot/refine_spec")
109+
async def refine_spec(input: RefineSpecApiInput) -> dict[str, Any]:
110+
api_key = _get_api_key()
111+
client = get_authenticated_client(api_key)
112+
113+
task_info = TaskInfo.from_dict(input.task_info)
114+
spec = SpecInfo.from_dict(input.spec)
115+
examples_with_feedback = [
116+
ExampleWithFeedback.from_dict(ex) for ex in input.examples_with_feedback
117+
]
118+
119+
refine_input = RefineSpecInput(
120+
task_prompt_with_few_shot=input.task_prompt_with_few_shot,
121+
task_input_schema=input.task_input_schema,
122+
task_output_schema=input.task_output_schema,
123+
task_info=task_info,
124+
spec=spec,
125+
examples_with_feedback=examples_with_feedback,
126+
)
127+
128+
result = await refine_spec_v1_copilot_refine_spec_post.asyncio(
129+
client=client,
130+
body=refine_input,
131+
)
132+
133+
if result is None:
134+
raise HTTPException(
135+
status_code=500, detail="Failed to refine spec: No response"
136+
)
137+
138+
if isinstance(result, HTTPValidationError):
139+
raise HTTPException(
140+
status_code=422,
141+
detail=f"Validation error: {result.to_dict()}",
142+
)
143+
144+
if isinstance(result, RefineSpecOutput):
145+
return result.to_dict()
146+
147+
raise HTTPException(
148+
status_code=500,
149+
detail=f"Failed to refine spec: Unexpected response type {type(result)}",
150+
)
151+
152+
@app.post("/api/copilot/generate_batch")
153+
async def generate_batch(input: GenerateBatchApiInput) -> dict[str, Any]:
154+
api_key = _get_api_key()
155+
client = get_authenticated_client(api_key)
156+
157+
generate_input = GenerateBatchInput(
158+
task_prompt_with_few_shot=input.task_prompt_with_few_shot,
159+
task_input_schema=input.task_input_schema,
160+
task_output_schema=input.task_output_schema,
161+
spec_rendered_prompt_template=input.spec_rendered_prompt_template,
162+
num_samples_per_topic=input.num_samples_per_topic,
163+
num_topics=input.num_topics,
164+
enable_scoring=input.enable_scoring,
165+
)
166+
167+
result = await generate_batch_v1_copilot_generate_batch_post.asyncio(
168+
client=client,
169+
body=generate_input,
170+
)
171+
172+
if result is None:
173+
raise HTTPException(
174+
status_code=500, detail="Failed to generate batch: No response"
175+
)
176+
177+
if isinstance(result, HTTPValidationError):
178+
raise HTTPException(
179+
status_code=422,
180+
detail=f"Validation error: {result.to_dict()}",
181+
)
182+
183+
if isinstance(result, GenerateBatchOutput):
184+
return result.to_dict()
185+
186+
raise HTTPException(
187+
status_code=500,
188+
detail=f"Failed to generate batch: Unexpected response type {type(result)}",
189+
)

0 commit comments

Comments
 (0)