Skip to content

Commit 44e8de4

Browse files
committed
Fix support for Annotated fields by preserving the underlying sqlmodel metadata.
1 parent a85de91 commit 44e8de4

File tree

3 files changed

+175
-18
lines changed

3 files changed

+175
-18
lines changed

sqlmodel/main.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ipaddress
44
import uuid
55
import weakref
6+
from dataclasses import dataclass
67
from datetime import date, datetime, time, timedelta
78
from decimal import Decimal
89
from enum import Enum
@@ -347,6 +348,38 @@ def Field(
347348
) -> Any: ...
348349

349350

351+
@dataclass
352+
class FieldInfoMetadata:
353+
primary_key: Union[bool, UndefinedType] = Undefined
354+
nullable: Union[bool, UndefinedType] = Undefined
355+
foreign_key: Any = Undefined
356+
ondelete: Union[OnDeleteType, UndefinedType] = Undefined
357+
unique: Union[bool, UndefinedType] = Undefined
358+
index: Union[bool, UndefinedType] = Undefined
359+
sa_type: Union[Type[Any], UndefinedType] = Undefined
360+
sa_column: Union[Column[Any], UndefinedType] = Undefined
361+
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
362+
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
363+
364+
365+
def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
366+
metadata_items = getattr(field_info, "metadata", None)
367+
if metadata_items:
368+
for meta in metadata_items:
369+
if isinstance(meta, FieldInfoMetadata):
370+
return meta
371+
return None
372+
373+
374+
def _get_sqlmodel_field_value(
375+
field_info: Any, attribute: str, default: Any = Undefined
376+
) -> Any:
377+
metadata = _get_sqlmodel_field_metadata(field_info)
378+
if metadata is not None and hasattr(metadata, attribute):
379+
return getattr(metadata, attribute)
380+
return getattr(field_info, attribute, default)
381+
382+
350383
def Field(
351384
default: Any = Undefined,
352385
*,
@@ -427,6 +460,20 @@ def Field(
427460
sa_column_kwargs=sa_column_kwargs,
428461
**current_schema_extra,
429462
)
463+
field_metadata = FieldInfoMetadata(
464+
primary_key=primary_key,
465+
nullable=nullable,
466+
foreign_key=foreign_key,
467+
ondelete=ondelete,
468+
unique=unique,
469+
index=index,
470+
sa_type=sa_type,
471+
sa_column=sa_column,
472+
sa_column_args=sa_column_args,
473+
sa_column_kwargs=sa_column_kwargs,
474+
)
475+
if hasattr(field_info, "metadata"):
476+
field_info.metadata.append(field_metadata) # type: ignore[attr-defined]
430477
post_init_field_info(field_info)
431478
return field_info
432479

@@ -651,7 +698,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
651698
field_info = field
652699
else:
653700
field_info = field.field_info
654-
sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009
701+
sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined)
655702
if sa_type is not Undefined:
656703
return sa_type
657704

@@ -708,39 +755,39 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
708755
field_info = field
709756
else:
710757
field_info = field.field_info
711-
sa_column = getattr(field_info, "sa_column", Undefined)
758+
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
712759
if isinstance(sa_column, Column):
713760
return sa_column
714761
sa_type = get_sqlalchemy_type(field)
715-
primary_key = getattr(field_info, "primary_key", Undefined)
762+
primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
716763
if primary_key is Undefined:
717764
primary_key = False
718-
index = getattr(field_info, "index", Undefined)
765+
index = _get_sqlmodel_field_value(field_info, "index", Undefined)
719766
if index is Undefined:
720767
index = False
721768
nullable = not primary_key and is_field_noneable(field)
722769
# Override derived nullability if the nullable property is set explicitly
723770
# on the field
724-
field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
771+
field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
725772
if field_nullable is not Undefined:
726773
assert not isinstance(field_nullable, UndefinedType)
727774
nullable = field_nullable
728775
args = []
729-
foreign_key = getattr(field_info, "foreign_key", Undefined)
776+
foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
730777
if foreign_key is Undefined:
731778
foreign_key = None
732-
unique = getattr(field_info, "unique", Undefined)
779+
unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
733780
if unique is Undefined:
734781
unique = False
735782
if foreign_key:
736-
if field_info.ondelete == "SET NULL" and not nullable:
783+
ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
784+
if ondelete_value is Undefined:
785+
ondelete_value = None
786+
if ondelete_value == "SET NULL" and not nullable:
737787
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
738788
assert isinstance(foreign_key, str)
739-
ondelete = getattr(field_info, "ondelete", Undefined)
740-
if ondelete is Undefined:
741-
ondelete = None
742-
assert isinstance(ondelete, (str, type(None))) # for typing
743-
args.append(ForeignKey(foreign_key, ondelete=ondelete))
789+
assert isinstance(ondelete_value, (str, type(None))) # for typing
790+
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
744791
kwargs = {
745792
"primary_key": primary_key,
746793
"nullable": nullable,
@@ -754,10 +801,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
754801
sa_default = field_info.default
755802
if sa_default is not Undefined:
756803
kwargs["default"] = sa_default
757-
sa_column_args = getattr(field_info, "sa_column_args", Undefined)
804+
sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
758805
if sa_column_args is not Undefined:
759806
args.extend(list(cast(Sequence[Any], sa_column_args)))
760-
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
807+
sa_column_kwargs = _get_sqlmodel_field_value(
808+
field_info, "sa_column_kwargs", Undefined
809+
)
761810
if sa_column_kwargs is not Undefined:
762811
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
763812
return Column(sa_type, *args, **kwargs) # type: ignore

tests/test_field_sa_column.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from sqlalchemy import Column, Integer, String
55
from sqlmodel import Field, SQLModel
6+
from typing_extensions import Annotated
67

78

89
def test_sa_column_takes_precedence() -> None:
@@ -17,6 +18,17 @@ class Item(SQLModel, table=True):
1718
assert isinstance(Item.id.type, String) # type: ignore
1819

1920

21+
def test_sa_column_with_annotated_metadata() -> None:
22+
class Item(SQLModel, table=True):
23+
id: Annotated[Optional[int], "meta"] = Field(
24+
default=None,
25+
sa_column=Column(String, primary_key=True, nullable=False),
26+
)
27+
28+
assert Item.id.nullable is False # type: ignore
29+
assert isinstance(Item.id.type, String) # type: ignore
30+
31+
2032
def test_sa_column_no_sa_args() -> None:
2133
with pytest.raises(RuntimeError):
2234

tests/test_main.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from typing import List, Optional
22

33
import pytest
4-
from sqlalchemy.exc import IntegrityError
5-
from sqlalchemy.orm import RelationshipProperty
6-
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
4+
from sqlalchemy.exc import IntegrityError
5+
from sqlalchemy.orm import RelationshipProperty
6+
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
7+
from typing_extensions import Annotated
8+
9+
from tests.conftest import needs_pydanticv2
710

811

912
def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
@@ -125,3 +128,96 @@ class Hero(SQLModel, table=True):
125128
# The next statement should not raise an AttributeError
126129
assert hero_rusty_man.team
127130
assert hero_rusty_man.team.name == "Preventers"
131+
132+
133+
def test_composite_primary_key(clear_sqlmodel):
134+
class UserPermission(SQLModel, table=True):
135+
user_id: int = Field(primary_key=True)
136+
resource_id: int = Field(primary_key=True)
137+
permission: str
138+
139+
engine = create_engine("sqlite://")
140+
SQLModel.metadata.create_all(engine)
141+
142+
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
143+
assert pk_column_names == {"user_id", "resource_id"}
144+
145+
with Session(engine) as session:
146+
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
147+
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
148+
session.add(perm1)
149+
session.add(perm2)
150+
session.commit()
151+
152+
with pytest.raises(IntegrityError):
153+
with Session(engine) as session:
154+
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
155+
session.add(perm3)
156+
session.commit()
157+
158+
159+
@needs_pydanticv2
160+
def test_composite_primary_key_and_validator(clear_sqlmodel):
161+
from pydantic import AfterValidator
162+
163+
def validate_resource_id(value: int) -> int:
164+
if value < 1:
165+
raise ValueError("Resource ID must be positive")
166+
return value
167+
168+
class UserPermission(SQLModel, table=True):
169+
user_id: int = Field(primary_key=True)
170+
resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(
171+
primary_key=True
172+
)
173+
permission: str
174+
175+
engine = create_engine("sqlite://")
176+
SQLModel.metadata.create_all(engine)
177+
178+
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
179+
assert pk_column_names == {"user_id", "resource_id"}
180+
181+
with Session(engine) as session:
182+
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
183+
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
184+
session.add(perm1)
185+
session.add(perm2)
186+
session.commit()
187+
188+
with pytest.raises(IntegrityError):
189+
with Session(engine) as session:
190+
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
191+
session.add(perm3)
192+
session.commit()
193+
194+
195+
@needs_pydanticv2
196+
def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):
197+
from pydantic import AfterValidator
198+
199+
def ensure_positive(value: int) -> int:
200+
if value < 0:
201+
raise ValueError("Team ID must be positive")
202+
return value
203+
204+
class Team(SQLModel, table=True):
205+
id: int = Field(primary_key=True)
206+
name: str
207+
208+
class Hero(SQLModel, table=True):
209+
id: int = Field(primary_key=True)
210+
team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(
211+
foreign_key="team.id",
212+
ondelete="CASCADE",
213+
)
214+
name: str
215+
216+
engine = create_engine("sqlite://")
217+
SQLModel.metadata.create_all(engine)
218+
219+
team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined]
220+
foreign_keys = list(team_id_column.foreign_keys)
221+
assert len(foreign_keys) == 1
222+
assert foreign_keys[0].ondelete == "CASCADE"
223+
assert team_id_column.nullable is False

0 commit comments

Comments
 (0)