Skip to content

Commit c1250eb

Browse files
authored
Simplify model parameter REST API (#1475)
* Simplifying the model parameter rest api * implemented pydantic model and fixed type inputs * fixing functionality of pydantic model * simple code revision * Fixing parameter response passing * only using parameters in the response
1 parent c4859ae commit c1250eb

File tree

3 files changed

+40
-55
lines changed

3 files changed

+40
-55
lines changed

packages/jupyter-ai/jupyter_ai/model_providers/parameter_schemas.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, cast
3-
4-
if TYPE_CHECKING:
5-
from typing import Any
6-
2+
from typing import Literal, Any
3+
from pydantic import BaseModel
74

85
PARAMETER_SCHEMAS: dict[str, dict[str, Any]] = {
96
"temperature": {
10-
"type": "number",
7+
"type": "float",
118
"min": 0,
129
"max": 2,
1310
"description": "Controls randomness in the output. Lower values make it more focused and deterministic."
1411
},
1512
"top_p": {
16-
"type": "number",
13+
"type": "float",
1714
"min": 0,
1815
"max": 1,
1916
"description": "Nucleus sampling parameter. Consider tokens with top_p probability mass."
@@ -69,13 +66,13 @@
6966
# },
7067

7168
"presence_penalty": {
72-
"type": "number",
69+
"type": "float",
7370
"min": -2,
7471
"max": 2,
7572
"description": "Penalize new tokens based on whether they appear in the text so far."
7673
},
7774
"frequency_penalty": {
78-
"type": "number",
75+
"type": "float",
7976
"min": -2,
8077
"max": 2,
8178
"description": "Penalize new tokens based on their frequency in the text so far."
@@ -110,26 +107,35 @@
110107
}
111108
}
112109

113-
def get_parameter_schema(param_name: str) -> dict[str, Any]:
110+
class ParameterSchema(BaseModel):
111+
"""Pydantic model for parameter schema definition."""
112+
type: Literal['boolean', 'integer', 'float', 'string', 'array', 'object']
113+
description: str
114+
115+
class GetModelParametersResponse(BaseModel):
116+
"""Pydantic model for GET model parameters response."""
117+
parameters: dict[str, ParameterSchema]
118+
parameter_names: list[str]
119+
120+
class UpdateModelParametersResponse(BaseModel):
121+
"""Pydantic model for PUT model parameters response."""
122+
parameters: dict[str, Any]
123+
124+
def get_parameter_schema(param_name: str) -> ParameterSchema:
114125
"""
115126
Get the schema for a specific parameter.
116-
117-
TODO: Define a Pydantic model for the parameter schema, e.g.
118-
`ParameterSchema`. Update the return type annotation to `ParameterSchema`.
119127
"""
120128
schema = PARAMETER_SCHEMAS.get(param_name)
121129
if schema is None:
122-
return {
123-
"type": "string",
124-
"description": f"Parameter {param_name} (schema not defined)"
125-
}
126-
return schema
130+
return ParameterSchema(
131+
type="string",
132+
description=f"Parameter {param_name} (schema not defined)"
133+
)
134+
return ParameterSchema(**schema)
127135

128-
def get_parameters_with_schemas(param_names: list[str]) -> dict[str, Any]:
136+
def get_parameters_with_schemas(param_names: list[str]) -> dict[str, ParameterSchema]:
129137
"""
130138
Get schemas for a list of parameter names.
131-
132-
TODO: Update the return type annotation to `dict[str, ParameterSchema]`.
133139
"""
134140
return {
135141
name: get_parameter_schema(name)

packages/jupyter-ai/jupyter_ai/model_providers/parameters_rest_api.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44

55
from litellm.litellm_core_utils.get_supported_openai_params import get_supported_openai_params
6-
from .parameter_schemas import get_parameters_with_schemas, coerce_parameter_value
6+
from .parameter_schemas import get_parameters_with_schemas, coerce_parameter_value, GetModelParametersResponse, UpdateModelParametersResponse
77
from ..config_manager import ConfigManager
88
from ..config import UpdateConfigRequest
99

@@ -56,23 +56,14 @@ def get(self):
5656
# Get parameter schemas with types, defaults, and descriptions
5757
parameters_with_schemas = get_parameters_with_schemas(parameter_names)
5858

59-
# TODO: Define the response type as a Pydantic model to prevent
60-
# breaking API changes, e.g. `GetModelParametersResponse`.
61-
#
62-
# TODO: replace 'number' with 'float' in parameter schemas for
63-
# clarity. 'number' is ambiguous as to whether it is an integer or
64-
# float. Make sure to update the frontend type.
65-
#
66-
# TODO: Drop 'count' from the response (and the corresponding types
67-
# across the frontend & backend).
68-
response = {
69-
"parameters": parameters_with_schemas,
70-
"parameter_names": parameter_names,
71-
"count": len(parameter_names)
72-
}
59+
# Create Pydantic response model
60+
response = GetModelParametersResponse(
61+
parameters=parameters_with_schemas,
62+
parameter_names=parameter_names
63+
)
7364

7465
self.set_header("Content-Type", "application/json")
75-
self.finish(json.dumps(response))
66+
self.finish(response.model_dump_json())
7667

7768
except Exception as e:
7869
self.log.exception("Failed to get model parameters")
@@ -127,21 +118,13 @@ def put(self):
127118
)
128119
config_manager.update_config(update_request)
129120

130-
# TODO: Determine what a response could be used for, and remove the
131-
# response if it is unnecessary. Right now the response body is
132-
# ignored by the frontend.
133-
#
134-
# TODO: Define the response type as a Pydantic model to prevent
135-
# breaking API changes, e.g. `UpdateModelParametersResponse`.
136-
response = {
137-
"status": "success",
138-
"message": f"Parameters saved for model {model_id}",
139-
"model_id": model_id,
140-
"parameters": coerced_parameters,
141-
}
121+
# Create Pydantic response model for API compatibility
122+
response = UpdateModelParametersResponse(
123+
parameters=coerced_parameters
124+
)
142125

143126
self.set_header("Content-Type", "application/json")
144-
self.finish(json.dumps(response))
127+
self.finish(response.model_dump_json())
145128

146129
except json.JSONDecodeError:
147130
raise HTTPError(400, "Invalid JSON in request body")

packages/jupyter-ai/src/handler.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,20 +200,16 @@ export namespace AiService {
200200
export type GetModelParametersResponse = {
201201
parameters: Record<string, ParameterSchema>;
202202
parameter_names: string[];
203-
count: number;
204203
};
205204

206205
export type ParameterSchema = {
207-
type: 'boolean' | 'integer' | 'number' | 'string' | 'array' | 'object';
206+
type: 'boolean' | 'integer' | 'float' | 'string' | 'array' | 'object';
208207
description: string;
209208
min?: number;
210209
max?: number;
211210
};
212211

213212
export type UpdateModelParametersResponse = {
214-
status: string;
215-
message: string;
216-
model_id: string;
217213
parameters: Record<string, any>;
218214
};
219215

0 commit comments

Comments
 (0)