Merged
14 changes: 12 additions & 2 deletions src/lmstudio/_kv_config.py
@@ -268,7 +268,7 @@ def prediction_config_to_kv_config_stack(
response_format: Type[ModelSchema] | DictSchema | None,
config: LlmPredictionConfig | LlmPredictionConfigDict | None,
for_text_completion: bool = False,
-) -> KvConfigStack:
+) -> tuple[bool, KvConfigStack]:
dict_config: DictObject
if config is None:
dict_config = {}
@@ -279,6 +279,7 @@ def prediction_config_to_kv_config_stack(
dict_config = LlmPredictionConfig._from_any_dict(config).to_dict()
response_schema: DictSchema | None = None
if response_format is not None:
+        structured = True
if "structured" in dict_config:
raise LMStudioValueError(
"Cannot specify both 'response_format' in API call and 'structured' in config"
@@ -289,6 +290,15 @@
response_schema = response_format.model_json_schema()
else:
response_schema = response_format
+    else:
+        # The response schema may also be passed in via the config
+        # (doing it this way type hints as an unstructured result,
+        # but we still allow it at runtime for consistency with JS)
+        match dict_config:
+            case {"structured": {"type": "json"}}:
+                structured = True
+            case _:
+                structured = False
fields = _to_kv_config_stack_base(
dict_config,
"llm",
@@ -308,7 +318,7 @@
additional_layers: list[KvConfigStackLayerDict] = []
if for_text_completion:
additional_layers.append(_get_completion_config_layer())
-    return _api_override_kv_config_stack(fields, additional_layers)
+    return structured, _api_override_kv_config_stack(fields, additional_layers)


def _get_completion_config_layer() -> KvConfigStackLayerDict:
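Taken together, the `_kv_config.py` changes mean a prediction is flagged as structured when either the `response_format` argument is supplied or the config itself carries a `{"structured": {"type": "json"}}` entry. A minimal standalone sketch of that detection logic, for illustration only (the helper name `_is_structured` is hypothetical, not part of the diff):

```python
from typing import Any


def _is_structured(response_format: object | None, dict_config: dict[str, Any]) -> bool:
    # An explicit response_format always marks the prediction as structured
    if response_format is not None:
        return True
    # Otherwise, a JSON "structured" entry in the config does the same
    match dict_config:
        case {"structured": {"type": "json"}}:
            return True
        case _:
            return False


assert _is_structured(None, {}) is False
assert _is_structured(None, {"temperature": 0.7}) is False
assert _is_structured(None, {"structured": {"type": "json"}}) is True
assert _is_structured(object(), {}) is True
```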
6 changes: 3 additions & 3 deletions src/lmstudio/json_api.py
@@ -1144,7 +1144,7 @@ def __init__(
config["rawTools"] = llm_tools.to_dict()
else:
config.raw_tools = llm_tools
-        config_stack = self._make_config_override(response_format, config)
+        structured, config_stack = self._make_config_override(response_format, config)
params = PredictionChannelRequest._from_api_dict(
{
"modelSpecifier": _model_spec_to_api_dict(model_specifier),
@@ -1155,7 +1155,7 @@
super().__init__(params)
# Status tracking for the prediction progress and result reporting
self._is_cancelled = False
-        self._structured = response_format is not None
+        self._structured = structured
self._on_message = on_message
self._prompt_processing_progress = -1.0
self._on_prompt_processing_progress = on_prompt_processing_progress
@@ -1172,7 +1172,7 @@ def _make_config_override(
cls,
response_format: Type[ModelSchema] | DictSchema | None,
config: LlmPredictionConfig | LlmPredictionConfigDict | None,
-    ) -> KvConfigStack:
+    ) -> tuple[bool, KvConfigStack]:
return prediction_config_to_kv_config_stack(
response_format, config, **cls._additional_config_options()
)
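Since `self._structured` is now derived from the merged config rather than only the `response_format` argument, passing a structured config dict also yields a parsed JSON result. A hedged usage sketch (assumes a reachable LM Studio instance and a loaded model; the model identifier below is a placeholder):

```python
import json

from lmstudio import Client

with Client() as client:
    llm = client.llm.model("placeholder-model-id")  # any loaded LLM works
    config = {
        "structured": {
            "type": "json",
            "jsonSchema": {
                "type": "object",
                "properties": {"response": {"type": "string"}},
                "required": ["response"],
            },
        }
    }
    response = llm.complete("Reply in JSON.", config=config)

# With this PR, the structured config marks the result as structured,
# so response.parsed is the decoded JSON dict rather than the raw text.
assert response.parsed == json.loads(response.content)
```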
20 changes: 16 additions & 4 deletions src/lmstudio/schemas.py
@@ -93,13 +93,23 @@ def model_json_schema(cls) -> DictSchema:
"useFp16ForKvCache": "useFp16ForKVCache",
}

+_SKIP_FIELD_RECURSION = set(
+    (
+        "json_schema",
+        "jsonSchema",
+    )
+)


def _snake_case_to_camelCase(key: str) -> str:
first, *rest = key.split("_")
camelCase = "".join((first, *(w.capitalize() for w in rest)))
return _CAMEL_CASE_OVERRIDES.get(camelCase, camelCase)


+# TODO: Rework this conversion to be based on the API struct definitions
+# * Only recurse into fields that allow substructs
+# * Only check fields with a snake case -> camel case name conversion
def _snake_case_keys_to_camelCase(data: DictObject) -> DictObject:
translated_data: dict[str, Any] = {}
dicts_to_process = [(data, translated_data)]
@@ -114,12 +124,14 @@ def _queue_dict(input_dict: DictObject, output_dict: dict[str, Any]) -> None:

for input_dict, output_dict in dicts_to_process:
for k, v in input_dict.items():
+            new_value: Any
match v:
case {}:
-                    new_dict: dict[str, Any] = {}
-                    _queue_dict(v, new_dict)
-                    new_value = new_dict
+                    if k in _SKIP_FIELD_RECURSION:
+                        new_value = v
+                    else:
+                        new_dict: dict[str, Any] = {}
+                        _queue_dict(v, new_dict)
+                        new_value = new_dict
case [*_]:
new_list: list[Any] = []
for item in v:
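The recursion skip exists because a JSON schema payload contains user-defined property names, not SDK config keys. A hypothetical before/after illustration (the config value here is an example, not taken from the diff):

```python
config = {
    "structured": {
        "type": "json",
        "json_schema": {
            "properties": {"first_word_in_response": {"type": "string"}},
        },
    }
}
# _snake_case_keys_to_camelCase still renames the "json_schema" key itself
# to "jsonSchema", but with _SKIP_FIELD_RECURSION its value is copied
# verbatim, so the user-defined property name "first_word_in_response"
# is no longer rewritten to "firstWordInResponse".
```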
35 changes: 33 additions & 2 deletions tests/async/test_inference_async.py
@@ -15,6 +15,7 @@
Chat,
DictSchema,
LlmInfo,
LlmPredictionConfigDict,
LlmPredictionFragment,
LlmPredictionStats,
LMStudioModelNotFoundError,
@@ -27,6 +28,8 @@
EXPECTED_LLM_ID,
PROMPT,
RESPONSE_FORMATS,
+    RESPONSE_SCHEMA,
+    SCHEMA_FIELDS,
SHORT_PREDICTION_CONFIG,
check_sdk_error,
)
@@ -93,7 +96,7 @@ async def test_complete_stream_async(caplog: LogCap) -> None:
@pytest.mark.asyncio
@pytest.mark.lmstudio
@pytest.mark.parametrize("format_type", RESPONSE_FORMATS)
-async def test_complete_structured_async(
+async def test_complete_response_format_async(
format_type: Type[ModelSchema] | DictSchema, caplog: LogCap
) -> None:
prompt = PROMPT
@@ -107,7 +110,35 @@ async def test_complete_structured_async(
assert isinstance(response.content, str)
assert isinstance(response.parsed, dict)
assert response.parsed == json.loads(response.content)
assert "response" in response.parsed
assert SCHEMA_FIELDS.keys() == response.parsed.keys()


+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_complete_structured_config_async(caplog: LogCap) -> None:
+    prompt = PROMPT
+    caplog.set_level(logging.DEBUG)
+    model_id = EXPECTED_LLM_ID
+    async with AsyncClient() as client:
+        llm = await client.llm.model(model_id)
+        config: LlmPredictionConfigDict = {
+            # snake_case keys are accepted at runtime,
+            # but the type hinted spelling is the camelCase names
+            # This test case checks the schema field name is converted,
+            # but *not* the snake_case and camelCase field names in the
+            # schema itself
+            "structured": {
+                "type": "json",
+                "json_schema": RESPONSE_SCHEMA,
+            }  # type: ignore[typeddict-item]
+        }
+        response = await llm.complete(prompt, config=config)
+    assert isinstance(response, PredictionResult)
+    logging.info(f"LLM response: {response!r}")
+    assert isinstance(response.content, str)
+    assert isinstance(response.parsed, dict)
+    assert response.parsed == json.loads(response.content)
+    assert SCHEMA_FIELDS.keys() == response.parsed.keys()


@pytest.mark.asyncio
30 changes: 22 additions & 8 deletions tests/support/__init__.py
@@ -82,22 +82,34 @@
# Structured LLM responses
####################################################

+# Schema includes both snake_case and camelCase field
+# names to ensure the special-casing of snake_case
+# fields in dict inputs doesn't corrupt schema inputs
+SCHEMA_FIELDS = {
+    "response": {
+        "type": "string",
+    },
+    "first_word_in_response": {
+        "type": "string",
+    },
+    "lastWordInResponse": {
+        "type": "string",
+    },
+}
+SCHEMA_FIELD_NAMES = list(SCHEMA_FIELDS.keys())

SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"required": ["response"],
"properties": {
"response": {
"type": "string",
}
},
"required": SCHEMA_FIELD_NAMES,
"properties": SCHEMA_FIELDS,
"additionalProperties": False,
}
RESPONSE_SCHEMA = {
"$defs": {
"schema": {
"properties": {"response": {"type": "string"}},
"required": ["response"],
"properties": SCHEMA_FIELDS,
"required": SCHEMA_FIELD_NAMES,
"title": "schema",
"type": "object",
}
Expand All @@ -114,6 +126,8 @@ def model_json_schema(cls) -> DictSchema:

class LMStudioResponseFormat(BaseModel):
response: str
+    first_word_in_response: str
+    lastWordInResponse: str


RESPONSE_FORMATS = (LMStudioResponseFormat, OtherResponseFormat, SCHEMA)
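With the expanded schema, a conforming completion must now populate all three fields. For example (illustrative values only):

```python
# A payload that satisfies the updated SCHEMA/RESPONSE_SCHEMA definitions
parsed = {
    "response": "Hello there",
    "first_word_in_response": "Hello",
    "lastWordInResponse": "there",
}
assert parsed.keys() == SCHEMA_FIELDS.keys()
```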
34 changes: 32 additions & 2 deletions tests/sync/test_inference_sync.py
@@ -22,6 +22,7 @@
Chat,
DictSchema,
LlmInfo,
+    LlmPredictionConfigDict,
LlmPredictionFragment,
LlmPredictionStats,
LMStudioModelNotFoundError,
@@ -34,6 +35,8 @@
EXPECTED_LLM_ID,
PROMPT,
RESPONSE_FORMATS,
+    RESPONSE_SCHEMA,
+    SCHEMA_FIELDS,
SHORT_PREDICTION_CONFIG,
check_sdk_error,
)
@@ -96,7 +99,7 @@ def test_complete_stream_sync(caplog: LogCap) -> None:

@pytest.mark.lmstudio
@pytest.mark.parametrize("format_type", RESPONSE_FORMATS)
-def test_complete_structured_sync(
+def test_complete_response_format_sync(
format_type: Type[ModelSchema] | DictSchema, caplog: LogCap
) -> None:
prompt = PROMPT
@@ -110,7 +113,34 @@ def test_complete_structured_sync(
assert isinstance(response.content, str)
assert isinstance(response.parsed, dict)
assert response.parsed == json.loads(response.content)
assert "response" in response.parsed
assert SCHEMA_FIELDS.keys() == response.parsed.keys()


+@pytest.mark.lmstudio
+def test_complete_structured_config_sync(caplog: LogCap) -> None:
+    prompt = PROMPT
+    caplog.set_level(logging.DEBUG)
+    model_id = EXPECTED_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        config: LlmPredictionConfigDict = {
+            # snake_case keys are accepted at runtime,
+            # but the type hinted spelling is the camelCase names
+            # This test case checks the schema field name is converted,
+            # but *not* the snake_case and camelCase field names in the
+            # schema itself
+            "structured": {
+                "type": "json",
+                "json_schema": RESPONSE_SCHEMA,
+            }  # type: ignore[typeddict-item]
+        }
+        response = llm.complete(prompt, config=config)
+    assert isinstance(response, PredictionResult)
+    logging.info(f"LLM response: {response!r}")
+    assert isinstance(response.content, str)
+    assert isinstance(response.parsed, dict)
+    assert response.parsed == json.loads(response.content)
+    assert SCHEMA_FIELDS.keys() == response.parsed.keys()


@pytest.mark.lmstudio
3 changes: 2 additions & 1 deletion tests/test_kv_config.py
@@ -419,7 +419,8 @@ def test_kv_stack_load_config_llm(config_dict: DictObject) -> None:
def test_kv_stack_prediction_config(config_dict: DictObject) -> None:
# MyPy complains here that it can't be sure the dict has all the right keys
# It is correct about that, but we want to ensure it is handled at runtime
-    kv_stack = prediction_config_to_kv_config_stack(None, config_dict)  # type: ignore[arg-type]
+    structured, kv_stack = prediction_config_to_kv_config_stack(None, config_dict)  # type: ignore[arg-type]
+    assert structured
assert kv_stack.to_dict() == EXPECTED_KV_STACK_PREDICTION

