Skip to content

Commit 43878b5

Browse files
committed
Support Annotated[..., Field()]
1 parent a85de91 commit 43878b5

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

sqlmodel/main.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,15 @@
5454
from sqlalchemy.orm.instrumentation import is_instrumented
5555
from sqlalchemy.sql.schema import MetaData
5656
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
57-
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
57+
from typing_extensions import (
58+
Annotated as TEAnnotated,
59+
Literal,
60+
TypeAlias,
61+
deprecated,
62+
get_args as te_get_args,
63+
get_origin,
64+
)
65+
from typing import Annotated as TypingAnnotated
5866

5967
from ._compat import ( # type: ignore[attr-defined]
6068
IS_PYDANTIC_V2,
@@ -546,6 +554,26 @@ def __new__(
546554
**new_cls.__annotations__,
547555
}
548556

557+
# For Pydantic v2: If a field used Annotated[..., Field(sa_column=Column(...))]
558+
# Pydantic might not lift our custom attribute onto the final FieldInfo.
559+
# Recover it from the original annotations before creating SQLAlchemy Columns.
560+
if IS_PYDANTIC_V2:
561+
for field_name, ann in original_annotations.items():
562+
try:
563+
origin = get_origin(ann)
564+
if origin in (TEAnnotated, TypingAnnotated):
565+
for extra in te_get_args(ann)[1:]:
566+
sa_col = getattr(extra, "sa_column", Undefined)
567+
if isinstance(sa_col, Column):
568+
# Attach found Column to the Pydantic field info
569+
model_fields = get_model_fields(new_cls)
570+
if field_name in model_fields:
571+
setattr(model_fields[field_name], "sa_column", sa_col)
572+
break
573+
except Exception:
574+
# Best-effort; fall back to default behavior
575+
pass
576+
549577
def get_config(name: str) -> Any:
550578
config_class_value = get_config_value(
551579
model=new_cls, parameter=name, default=Undefined
@@ -562,6 +590,26 @@ def get_config(name: str) -> Any:
562590
# If it was passed by kwargs, ensure it's also set in config
563591
set_config_value(model=new_cls, parameter="table", value=config_table)
564592
for k, v in get_model_fields(new_cls).items():
593+
# Prefer a Column passed via Annotated[..., Field(sa_column=...)]
594+
if IS_PYDANTIC_V2:
595+
ann = original_annotations.get(k, None)
596+
if ann is not None:
597+
try:
598+
origin = get_origin(ann)
599+
if origin in (TEAnnotated, TypingAnnotated):
600+
for extra in te_get_args(ann)[1:]:
601+
sa_col = getattr(extra, "sa_column", Undefined)
602+
if isinstance(sa_col, Column):
603+
setattr(new_cls, k, sa_col)
604+
break
605+
else:
606+
# no Column override found, build normally
607+
col = get_column_from_field(v)
608+
setattr(new_cls, k, col)
609+
continue
610+
except Exception:
611+
# Fall back to normal column building
612+
pass
565613
col = get_column_from_field(v)
566614
setattr(new_cls, k, col)
567615
# Set a config flag to tell FastAPI that this should be read with a field
@@ -709,6 +757,27 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
709757
else:
710758
field_info = field.field_info
711759
sa_column = getattr(field_info, "sa_column", Undefined)
760+
# In Pydantic v2, when using Annotated[T, Field(...)], the Field(...) object
761+
# is stored in the field's metadata and some custom attributes (like
762+
# sa_column) might not be lifted onto the main FieldInfo. Inspect metadata
763+
# to honor a Column passed via Annotated Field(...).
764+
if IS_PYDANTIC_V2 and not isinstance(sa_column, Column):
765+
# Try to recover a Column passed via Annotated[..., Field(sa_column=...)]
766+
raw_ann = getattr(field, "annotation", None)
767+
origin = get_origin(raw_ann)
768+
if origin in (TEAnnotated, TypingAnnotated):
769+
for extra in te_get_args(raw_ann)[1:]:
770+
meta_sa_column = getattr(extra, "sa_column", Undefined)
771+
if isinstance(meta_sa_column, Column):
772+
sa_column = meta_sa_column
773+
break
774+
# Also check field metadata in case custom FieldInfo leaked through
775+
if not isinstance(sa_column, Column):
776+
for meta in getattr(field, "metadata", ()):
777+
meta_sa_column = getattr(meta, "sa_column", Undefined)
778+
if isinstance(meta_sa_column, Column):
779+
sa_column = meta_sa_column
780+
break
712781
if isinstance(sa_column, Column):
713782
return sa_column
714783
sa_type = get_sqlalchemy_type(field)

tests/test_field_sa_column.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from typing import Optional
2+
from datetime import datetime
3+
from typing_extensions import Annotated
24

35
import pytest
4-
from sqlalchemy import Column, Integer, String
6+
from sqlalchemy import Column, Integer, String, DateTime
57
from sqlmodel import Field, SQLModel
8+
from tests.conftest import needs_pydanticv2
69

710

811
def test_sa_column_takes_precedence() -> None:
@@ -119,3 +122,16 @@ class Item(SQLModel, table=True):
119122
sa_column=Column(Integer, primary_key=True),
120123
ondelete="CASCADE",
121124
)
125+
126+
127+
@needs_pydanticv2
128+
def test_sa_column_in_annotated_is_respected() -> None:
129+
class Item(SQLModel, table=True):
130+
id: Optional[int] = Field(default=None, primary_key=True)
131+
available_at: Annotated[
132+
datetime, Field(sa_column=Column(DateTime(timezone=True)))
133+
]
134+
135+
# Should reflect timezone=True from the provided Column
136+
assert isinstance(Item.available_at.type, DateTime) # type: ignore[attr-defined]
137+
assert Item.available_at.type.timezone is True # type: ignore[attr-defined]

0 commit comments

Comments
 (0)