diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..f8bbb46e7b 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -221,7 +221,13 @@ def get_field_metadata(field: Any) -> Any: return FakeMetadata() def post_init_field_info(field_info: FieldInfo) -> None: - return None + if IS_PYDANTIC_V2: + if field_info.alias and not field_info.validation_alias: + field_info.validation_alias = field_info.alias + if field_info.alias and not field_info.serialization_alias: + field_info.serialization_alias = field_info.alias + else: + field_info._validate() # type: ignore[attr-defined] # Dummy to make it importable def _calculate_keys( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..e9b732a369 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -215,6 +215,8 @@ def Field( *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, + validation_alias: Optional[str] = None, + serialization_alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, exclude: Union[ @@ -260,6 +262,8 @@ def Field( *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, + validation_alias: Optional[str] = None, + serialization_alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, exclude: Union[ @@ -314,6 +318,8 @@ def Field( *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, + validation_alias: Optional[str] = None, + serialization_alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, exclude: Union[ @@ -349,6 +355,8 @@ def Field( *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, + validation_alias: Optional[str] = None, + serialization_alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, exclude: Union[ @@ -387,43 +395,60 @@ def Field( schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + field_info_kwargs = { + "alias": alias, + "title": title, + "description": description, + "exclude": exclude, + "include": include, + "const": const, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_items": min_items, + "max_items": max_items, + "unique_items": unique_items, + "min_length": min_length, + "max_length": max_length, + "allow_mutation": allow_mutation, + "regex": regex, + "discriminator": discriminator, + "repr": repr, + "primary_key": primary_key, + "foreign_key": foreign_key, + "ondelete": ondelete, + "unique": unique, + "nullable": nullable, + "index": index, + "sa_type": sa_type, + "sa_column": sa_column, + "sa_column_args": sa_column_args, + "sa_column_kwargs": sa_column_kwargs, + **current_schema_extra, + } + if IS_PYDANTIC_V2: + # Add Pydantic v2 specific parameters + field_info_kwargs.update( + { + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + } + ) + else: + if validation_alias: + raise RuntimeError("validation_alias is not supported in Pydantic v1") + if serialization_alias: + raise RuntimeError("serialization_alias is not supported in Pydantic v1") field_info = FieldInfo( default, default_factory=default_factory, - alias=alias, - title=title, - description=description, - exclude=exclude, - include=include, - const=const, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - max_digits=max_digits, - decimal_places=decimal_places, - min_items=min_items, - max_items=max_items, - unique_items=unique_items, - min_length=min_length, - max_length=max_length, - allow_mutation=allow_mutation, - regex=regex, - discriminator=discriminator, - repr=repr, - primary_key=primary_key, - foreign_key=foreign_key, - ondelete=ondelete, - unique=unique, - nullable=nullable, - index=index, - sa_type=sa_type, - sa_column=sa_column, - sa_column_args=sa_column_args, - sa_column_kwargs=sa_column_kwargs, - **current_schema_extra, + **field_info_kwargs, ) + post_init_field_info(field_info) return field_info diff --git a/tests/test_aliases.py b/tests/test_aliases.py new file mode 100644 index 0000000000..ea32002a95 --- /dev/null +++ b/tests/test_aliases.py @@ -0,0 +1,178 @@ +from typing import Type, Union + +import pytest +from pydantic import VERSION, BaseModel, ValidationError +from pydantic import Field as PField +from sqlmodel import Field, SQLModel + +from tests.conftest import needs_pydanticv2 + +""" +Alias tests for SQLModel and Pydantic compatibility +""" + + +class PydanticUser(BaseModel): + full_name: str = PField(alias="fullName") + + +class SQLModelUser(SQLModel): + full_name: str = Field(alias="fullName") + + +# Models with config (validate_by_name=True) + + +if VERSION.startswith("2."): + + class PydanticUserWithConfig(PydanticUser): + model_config = {"validate_by_name": True} + + class SQLModelUserWithConfig(SQLModelUser): + model_config = {"validate_by_name": True} + +else: + + class PydanticUserWithConfig(PydanticUser): + class Config: + allow_population_by_field_name = True + + class SQLModelUserWithConfig(SQLModelUser): + class Config: + allow_population_by_field_name = True + + +@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser]) +def test_create_with_field_name(model: Union[Type[PydanticUser], Type[SQLModelUser]]): + with pytest.raises(ValidationError): + model(full_name="Alice") + + +@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig]) +def test_create_with_field_name_with_config( + model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]], +): + user = model(full_name="Alice") + assert user.full_name == "Alice" + + +@pytest.mark.parametrize( + "model", + [PydanticUser, SQLModelUser, PydanticUserWithConfig, SQLModelUserWithConfig], +) +def test_create_with_alias( + model: Union[ + Type[PydanticUser], + Type[SQLModelUser], + Type[PydanticUserWithConfig], + Type[SQLModelUserWithConfig], + ], +): + user = model(fullName="Bob") # using alias + assert user.full_name == "Bob" + + +@pytest.mark.parametrize("model", [PydanticUserWithConfig, SQLModelUserWithConfig]) +def test_create_with_both_prefers_alias( + model: Union[Type[PydanticUserWithConfig], Type[SQLModelUserWithConfig]], +): + user = model(full_name="IGNORED", fullName="Charlie") + assert user.full_name == "Charlie" # alias should take precedence + + +@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser]) +def test_dict_default_uses_field_names( + model: Union[Type[PydanticUser], Type[SQLModelUser]], +): + user = model(fullName="Dana") + if VERSION.startswith("2."): + data = user.model_dump() + else: + data = user.dict() + assert "full_name" in data + assert "fullName" not in data + assert data["full_name"] == "Dana" + + +@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser]) +def test_dict_default_uses_aliases( + model: Union[Type[PydanticUser], Type[SQLModelUser]], +): + user = model(fullName="Dana") + if VERSION.startswith("2."): + data = user.model_dump(by_alias=True) + else: + data = user.dict(by_alias=True) + assert "fullName" in data + assert "full_name" not in data + assert data["fullName"] == "Dana" + + +@pytest.mark.parametrize("model", [PydanticUser, SQLModelUser]) +def test_json_by_alias( + model: Union[Type[PydanticUser], Type[SQLModelUser]], +): + user = model(fullName="Frank") + if VERSION.startswith("2."): + json_data = user.model_dump_json(by_alias=True) + else: + json_data = user.json(by_alias=True) + assert ('"fullName":"Frank"' in json_data) or ('"fullName": "Frank"' in json_data) + assert "full_name" not in json_data + + +if VERSION.startswith("2."): + + class PydanticUserV2(BaseModel): + first_name: str = PField( + validation_alias="firstName", serialization_alias="f_name" + ) + + class SQLModelUserV2(SQLModel): + first_name: str = Field( + validation_alias="firstName", serialization_alias="f_name" + ) +else: + # Dummy classes for Pydantic v1 to prevent import errors + PydanticUserV2 = None + SQLModelUserV2 = None + + +def test_validation_alias_runtimeerror_pydantic_v1(): + if VERSION.startswith("2."): + pytest.skip("Only relevant for Pydantic v1") + with pytest.raises( + RuntimeError, match="validation_alias is not supported in Pydantic v1" + ): + Field(validation_alias="foo") + + +def test_serialization_alias_runtimeerror_pydantic_v1(): + if VERSION.startswith("2."): + pytest.skip("Only relevant for Pydantic v1") + with pytest.raises( + RuntimeError, match="serialization_alias is not supported in Pydantic v1" + ): + Field(serialization_alias="bar") + + +@needs_pydanticv2 +@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2]) +def test_create_with_validation_alias( + model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]], +): + user = model(firstName="John") + assert user.first_name == "John" + + +@needs_pydanticv2 +@pytest.mark.parametrize("model", [PydanticUserV2, SQLModelUserV2]) +def test_serialize_with_serialization_alias( + model: Union[Type[PydanticUserV2], Type[SQLModelUserV2]], +): + user = model(firstName="Jane") + data = user.model_dump(by_alias=True) + assert "f_name" in data + assert "firstName" not in data + assert "first_name" not in data + assert data["f_name"] == "Jane"