diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..c9c9f81986 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,7 +1,9 @@ from __future__ import annotations +import inspect as inspect_module import ipaddress import uuid +import warnings import weakref from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -28,6 +30,7 @@ ) from pydantic import BaseModel, EmailStr +from pydantic import Field as PydanticField from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -54,7 +57,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import Annotated, Literal, TypeAlias, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -100,6 +103,10 @@ ] OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] +FIELD_ACCEPTED_KWARGS = set(inspect_module.signature(PydanticField).parameters.keys()) +if "schema_extra" in FIELD_ACCEPTED_KWARGS: + FIELD_ACCEPTED_KWARGS.remove("schema_extra") + def __dataclass_transform__( *, @@ -251,7 +258,16 @@ def Field( sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[Dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[Dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -297,7 +313,16 @@ def Field( sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[Dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[Dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -343,7 +368,16 @@ def Field( discriminator: Optional[str] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, - schema_extra: Optional[Dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[Dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -387,9 +421,47 @@ def Field( sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, - schema_extra: Optional[Dict[str, Any]] = None, + schema_extra: Annotated[ + Optional[Dict[str, Any]], + deprecated( + """ + This parameter is deprecated. + Use `json_schema_extra` to add extra information to JSON schema. + """ + ), + ] = None, + json_schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: + if schema_extra: + warnings.warn( + "schema_extra parameter is deprecated. " + "Use json_schema_extra to add extra information to JSON schema.", + DeprecationWarning, + stacklevel=1, + ) + + field_info_kwargs = {} + current_json_schema_extra = json_schema_extra or {} current_schema_extra = schema_extra or {} + + if IS_PYDANTIC_V2: + # Handle a workaround when json_schema_extra was passed via schema_extra + if "json_schema_extra" in current_schema_extra: + json_schema_extra_from_schema_extra = current_schema_extra.pop( + "json_schema_extra" + ) + if not current_json_schema_extra: + current_json_schema_extra = json_schema_extra_from_schema_extra + # Split parameters from schema_extra to field_info_kwargs and json_schema_extra + for key, value in current_schema_extra.items(): + if key in FIELD_ACCEPTED_KWARGS: + field_info_kwargs[key] = value + else: + current_json_schema_extra[key] = value + field_info_kwargs["json_schema_extra"] = current_json_schema_extra + else: + field_info_kwargs.update(current_json_schema_extra or current_schema_extra) + field_info = FieldInfo( default, default_factory=default_factory, @@ -425,7 +497,7 @@ def Field( sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, - **current_schema_extra, + **field_info_kwargs, ) post_init_field_info(field_info) return field_info diff --git a/tests/test_field_json_schema_extra.py b/tests/test_field_json_schema_extra.py new file mode 100644 index 0000000000..4d0c4ede1e --- /dev/null +++ b/tests/test_field_json_schema_extra.py @@ -0,0 +1,100 @@ +import pytest +from sqlmodel import Field, SQLModel +from sqlmodel._compat import IS_PYDANTIC_V2 + +from tests.conftest import needs_pydanticv2 + + +def test_json_schema_extra_applied(): + """test json_schema_extra is applied to the field""" + + class Item(SQLModel): + name: str = Field( + json_schema_extra={ + "example": "Sword of Power", + "x-custom-key": "Important Data", + } + ) + + if IS_PYDANTIC_V2: + schema = Item.model_json_schema() + else: + schema = Item.schema() + + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + +def test_schema_extra_and_json_schema_extra_conflict(): + """ + Test that passing schema_extra and json_schema_extra at the same time produces + a warning. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + Field(schema_extra={"legacy": 1}, json_schema_extra={"new": 2}) + + +def test_schema_extra_backward_compatibility(): + """ + test that schema_extra is backward compatible with json_schema_extra + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class LegacyItem(SQLModel): + name: str = Field( + schema_extra={ + "example": "Sword of Old", + "x-custom-key": "Important Data", + "serialization_alias": "id_test", + } + ) + + if IS_PYDANTIC_V2: + schema = LegacyItem.model_json_schema() + else: + schema = LegacyItem.schema() + + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Old" + assert name_schema["x-custom-key"] == "Important Data" + + if IS_PYDANTIC_V2: + # With Pydantic V1 serialization_alias from schema_extra is applied + field_info = LegacyItem.model_fields["name"] + assert field_info.serialization_alias == "id_test" + else: # With Pydantic V1 it just goes to schema + assert name_schema["serialization_alias"] == "id_test" + + +@needs_pydanticv2 +def test_json_schema_extra_mix_in_schema_extra(): + """ + Test workaround when json_schema_extra was passed via schema_extra with Pydantic v2. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class Item(SQLModel): + name: str = Field( + schema_extra={ + "json_schema_extra": { + "example": "Sword of Power", + "x-custom-key": "Important Data", + }, + "serialization_alias": "id_test", + } + ) + + schema = Item.model_json_schema() + + name_schema = schema["properties"]["name"] + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + field_info = Item.model_fields["name"] + assert field_info.serialization_alias == "id_test"