diff --git a/src/lmstudio/_kv_config.py b/src/lmstudio/_kv_config.py index a3e7541..e558531 100644 --- a/src/lmstudio/_kv_config.py +++ b/src/lmstudio/_kv_config.py @@ -268,7 +268,7 @@ def prediction_config_to_kv_config_stack( response_format: Type[ModelSchema] | DictSchema | None, config: LlmPredictionConfig | LlmPredictionConfigDict | None, for_text_completion: bool = False, -) -> KvConfigStack: +) -> tuple[bool, KvConfigStack]: dict_config: DictObject if config is None: dict_config = {} @@ -279,6 +279,7 @@ def prediction_config_to_kv_config_stack( dict_config = LlmPredictionConfig._from_any_dict(config).to_dict() response_schema: DictSchema | None = None if response_format is not None: + structured = True if "structured" in dict_config: raise LMStudioValueError( "Cannot specify both 'response_format' in API call and 'structured' in config" @@ -289,6 +290,15 @@ def prediction_config_to_kv_config_stack( response_schema = response_format.model_json_schema() else: response_schema = response_format + else: + # The response schema may also be passed in via the config + # (doing it this way type hints as an unstructured result, + # but we still allow it at runtime for consistency with JS) + match dict_config: + case {"structured": {"type": "json"}}: + structured = True + case _: + structured = False fields = _to_kv_config_stack_base( dict_config, "llm", @@ -308,7 +318,7 @@ def prediction_config_to_kv_config_stack( additional_layers: list[KvConfigStackLayerDict] = [] if for_text_completion: additional_layers.append(_get_completion_config_layer()) - return _api_override_kv_config_stack(fields, additional_layers) + return structured, _api_override_kv_config_stack(fields, additional_layers) def _get_completion_config_layer() -> KvConfigStackLayerDict: diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 1464dbe..d8b92fa 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -1144,7 +1144,7 @@ def __init__( config["rawTools"] = llm_tools.to_dict() else: config.raw_tools = llm_tools - config_stack = self._make_config_override(response_format, config) + structured, config_stack = self._make_config_override(response_format, config) params = PredictionChannelRequest._from_api_dict( { "modelSpecifier": _model_spec_to_api_dict(model_specifier), @@ -1155,7 +1155,7 @@ def __init__( super().__init__(params) # Status tracking for the prediction progress and result reporting self._is_cancelled = False - self._structured = response_format is not None + self._structured = structured self._on_message = on_message self._prompt_processing_progress = -1.0 self._on_prompt_processing_progress = on_prompt_processing_progress @@ -1172,7 +1172,7 @@ def _make_config_override( cls, response_format: Type[ModelSchema] | DictSchema | None, config: LlmPredictionConfig | LlmPredictionConfigDict | None, - ) -> KvConfigStack: + ) -> tuple[bool, KvConfigStack]: return prediction_config_to_kv_config_stack( response_format, config, **cls._additional_config_options() ) diff --git a/src/lmstudio/schemas.py b/src/lmstudio/schemas.py index 301ed2e..6c08281 100644 --- a/src/lmstudio/schemas.py +++ b/src/lmstudio/schemas.py @@ -93,6 +93,13 @@ def model_json_schema(cls) -> DictSchema: "useFp16ForKvCache": "useFp16ForKVCache", } +_SKIP_FIELD_RECURSION = set( + ( + "json_schema", + "jsonSchema", + ) +) + def _snake_case_to_camelCase(key: str) -> str: first, *rest = key.split("_") @@ -100,6 +107,9 @@ def _snake_case_to_camelCase(key: str) -> str: return _CAMEL_CASE_OVERRIDES.get(camelCase, camelCase) +# TODO: Rework this conversion to be based on the API struct definitions +# * Only recurse into fields that allow substructs +# * Only check fields with a snake case -> camel case name conversion def _snake_case_keys_to_camelCase(data: DictObject) -> DictObject: translated_data: dict[str, Any] = {} dicts_to_process = [(data, translated_data)] @@ -114,12 +124,14 @@ def _queue_dict(input_dict: DictObject, output_dict: dict[str, Any]) -> None: for input_dict, output_dict in dicts_to_process: for k, v in input_dict.items(): - new_value: Any match v: case {}: - new_dict: dict[str, Any] = {} - _queue_dict(v, new_dict) - new_value = new_dict + if k in _SKIP_FIELD_RECURSION: + new_value = v + else: + new_dict: dict[str, Any] = {} + _queue_dict(v, new_dict) + new_value = new_dict case [*_]: new_list: list[Any] = [] for item in v: diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py index 7259a88..f653e89 100644 --- a/tests/async/test_inference_async.py +++ b/tests/async/test_inference_async.py @@ -15,6 +15,7 @@ Chat, DictSchema, LlmInfo, + LlmPredictionConfigDict, LlmPredictionFragment, LlmPredictionStats, LMStudioModelNotFoundError, @@ -27,6 +28,8 @@ EXPECTED_LLM_ID, PROMPT, RESPONSE_FORMATS, + RESPONSE_SCHEMA, + SCHEMA_FIELDS, SHORT_PREDICTION_CONFIG, check_sdk_error, ) @@ -93,7 +96,7 @@ async def test_complete_stream_async(caplog: LogCap) -> None: @pytest.mark.asyncio @pytest.mark.lmstudio @pytest.mark.parametrize("format_type", RESPONSE_FORMATS) -async def test_complete_structured_async( +async def test_complete_response_format_async( format_type: Type[ModelSchema] | DictSchema, caplog: LogCap ) -> None: prompt = PROMPT @@ -107,7 +110,35 @@ async def test_complete_structured_async( assert isinstance(response.content, str) assert isinstance(response.parsed, dict) assert response.parsed == json.loads(response.content) - assert "response" in response.parsed + assert SCHEMA_FIELDS.keys() == response.parsed.keys() + + +@pytest.mark.asyncio +@pytest.mark.lmstudio +async def test_complete_structured_config_async(caplog: LogCap) -> None: + prompt = PROMPT + caplog.set_level(logging.DEBUG) + model_id = EXPECTED_LLM_ID + async with AsyncClient() as client: + llm = await client.llm.model(model_id) + config: LlmPredictionConfigDict = { + # snake_case keys are accepted at runtime, + # but the type hinted spelling is the camelCase names + # This test case checks the schema field name is converted, + # but *not* the snake_case and camelCase field names in the + # schema itself + "structured": { + "type": "json", + "json_schema": RESPONSE_SCHEMA, + } # type: ignore[typeddict-item] + } + response = await llm.complete(prompt, config=config) + assert isinstance(response, PredictionResult) + logging.info(f"LLM response: {response!r}") + assert isinstance(response.content, str) + assert isinstance(response.parsed, dict) + assert response.parsed == json.loads(response.content) + assert SCHEMA_FIELDS.keys() == response.parsed.keys() @pytest.mark.asyncio diff --git a/tests/support/__init__.py b/tests/support/__init__.py index 93175ff..57aab78 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -82,22 +82,34 @@ # Structured LLM responses #################################################### +# Schema includes both snake_case and camelCase field +# names to ensure the special-casing of snake_case +# fields in dict inputs doesn't corrupt schema inputs +SCHEMA_FIELDS = { + "response": { + "type": "string", + }, + "first_word_in_response": { + "type": "string", + }, + "lastWordInResponse": { + "type": "string", + }, +} +SCHEMA_FIELD_NAMES = list(SCHEMA_FIELDS.keys()) + SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", - "required": ["response"], - "properties": { - "response": { - "type": "string", - } - }, + "required": SCHEMA_FIELD_NAMES, + "properties": SCHEMA_FIELDS, "additionalProperties": False, } RESPONSE_SCHEMA = { "$defs": { "schema": { - "properties": {"response": {"type": "string"}}, - "required": ["response"], + "properties": SCHEMA_FIELDS, + "required": SCHEMA_FIELD_NAMES, "title": "schema", "type": "object", } @@ -114,6 +126,8 @@ def model_json_schema(cls) -> DictSchema: class LMStudioResponseFormat(BaseModel): response: str + first_word_in_response: str + lastWordInResponse: str RESPONSE_FORMATS = (LMStudioResponseFormat, OtherResponseFormat, SCHEMA) diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py index 1c1533e..b0bd21b 100644 --- a/tests/sync/test_inference_sync.py +++ b/tests/sync/test_inference_sync.py @@ -22,6 +22,7 @@ Chat, DictSchema, LlmInfo, + LlmPredictionConfigDict, LlmPredictionFragment, LlmPredictionStats, LMStudioModelNotFoundError, @@ -34,6 +35,8 @@ EXPECTED_LLM_ID, PROMPT, RESPONSE_FORMATS, + RESPONSE_SCHEMA, + SCHEMA_FIELDS, SHORT_PREDICTION_CONFIG, check_sdk_error, ) @@ -96,7 +99,7 @@ def test_complete_stream_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio @pytest.mark.parametrize("format_type", RESPONSE_FORMATS) -def test_complete_structured_sync( +def test_complete_response_format_sync( format_type: Type[ModelSchema] | DictSchema, caplog: LogCap ) -> None: prompt = PROMPT @@ -110,7 +113,34 @@ def test_complete_structured_sync( assert isinstance(response.content, str) assert isinstance(response.parsed, dict) assert response.parsed == json.loads(response.content) - assert "response" in response.parsed + assert SCHEMA_FIELDS.keys() == response.parsed.keys() + + +@pytest.mark.lmstudio +def test_complete_structured_config_sync(caplog: LogCap) -> None: + prompt = PROMPT + caplog.set_level(logging.DEBUG) + model_id = EXPECTED_LLM_ID + with Client() as client: + llm = client.llm.model(model_id) + config: LlmPredictionConfigDict = { + # snake_case keys are accepted at runtime, + # but the type hinted spelling is the camelCase names + # This test case checks the schema field name is converted, + # but *not* the snake_case and camelCase field names in the + # schema itself + "structured": { + "type": "json", + "json_schema": RESPONSE_SCHEMA, + } # type: ignore[typeddict-item] + } + response = llm.complete(prompt, config=config) + assert isinstance(response, PredictionResult) + logging.info(f"LLM response: {response!r}") + assert isinstance(response.content, str) + assert isinstance(response.parsed, dict) + assert response.parsed == json.loads(response.content) + assert SCHEMA_FIELDS.keys() == response.parsed.keys() @pytest.mark.lmstudio diff --git a/tests/test_kv_config.py b/tests/test_kv_config.py index d493491..f9f2a30 100644 --- a/tests/test_kv_config.py +++ b/tests/test_kv_config.py @@ -419,7 +419,8 @@ def test_kv_stack_load_config_llm(config_dict: DictObject) -> None: def test_kv_stack_prediction_config(config_dict: DictObject) -> None: # MyPy complains here that it can't be sure the dict has all the right keys # It is correct about that, but we want to ensure it is handled at runtime - kv_stack = prediction_config_to_kv_config_stack(None, config_dict) # type: ignore[arg-type] + structured, kv_stack = prediction_config_to_kv_config_stack(None, config_dict) # type: ignore[arg-type] + assert structured assert kv_stack.to_dict() == EXPECTED_KV_STACK_PREDICTION