From bef160292635b7057641ff368a507342db08c3fa Mon Sep 17 00:00:00 2001 From: Paul Roever Date: Tue, 26 Mar 2024 16:11:56 +0000 Subject: [PATCH 1/3] support passing fields as annotated types --- sqlmodel/main.py | 49 +++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e8330d69d..dfba3dbe01 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -51,7 +51,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time -from typing_extensions import Literal, deprecated, get_origin +from typing_extensions import Literal, _AnnotatedAlias, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -561,48 +561,59 @@ def get_sqlalchemy_type(field: Any) -> Any: return sa_type type_ = get_type_from_field(field) - metadata = get_field_metadata(field) + if isinstance(type_, _AnnotatedAlias): + class_to_compare = type_.__origin__ + if len(type_.__metadata__) == 1: + metadata = get_field_metadata(type_.__metadata__[0]) + else: + # not sure if this is the right behavior + raise ValueError( + f"AnnotatedAlias with multiple metadata is not supported: {type_}" + ) + else: + class_to_compare = type_ + metadata = get_field_metadata(field) # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(type_, Enum): + if issubclass(class_to_compare, Enum): return sa_Enum(type_) - if issubclass(type_, str): + if issubclass(class_to_compare, str): max_length = getattr(metadata, "max_length", None) if max_length: return AutoString(length=max_length) return AutoString - if issubclass(type_, float): + if issubclass(class_to_compare, float): return Float - if issubclass(type_, bool): + if issubclass(class_to_compare, bool): return Boolean - if issubclass(type_, int): + if issubclass(class_to_compare, int): return Integer - if issubclass(type_, datetime): + if issubclass(class_to_compare, datetime): return DateTime - if issubclass(type_, date): + if issubclass(class_to_compare, date): return Date - if issubclass(type_, timedelta): + if issubclass(class_to_compare, timedelta): return Interval - if issubclass(type_, time): + if issubclass(class_to_compare, time): return Time - if issubclass(type_, bytes): + if issubclass(class_to_compare, bytes): return LargeBinary - if issubclass(type_, Decimal): + if issubclass(class_to_compare, Decimal): return Numeric( precision=getattr(metadata, "max_digits", None), scale=getattr(metadata, "decimal_places", None), ) - if issubclass(type_, ipaddress.IPv4Address): + if issubclass(class_to_compare, ipaddress.IPv4Address): return AutoString - if issubclass(type_, ipaddress.IPv4Network): + if issubclass(class_to_compare, ipaddress.IPv4Network): return AutoString - if issubclass(type_, ipaddress.IPv6Address): + if issubclass(class_to_compare, ipaddress.IPv6Address): return AutoString - if issubclass(type_, ipaddress.IPv6Network): + if issubclass(class_to_compare, ipaddress.IPv6Network): return AutoString - if issubclass(type_, Path): + if issubclass(class_to_compare, Path): return AutoString - if issubclass(type_, uuid.UUID): + if issubclass(class_to_compare, uuid.UUID): return GUID raise ValueError(f"{type_} has no matching SQLAlchemy type") From 605291fce2458f75808cf3382fa026b493b7dcc2 Mon Sep 17 00:00:00 2001 From: Paul Roever Date: Tue, 26 Mar 2024 16:13:46 +0000 Subject: [PATCH 2/3] add test for optional annotated decimal field --- tests/test_main.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..3b99833c9c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from decimal import Decimal +from typing import Annotated, List, Optional import pytest from sqlalchemy.exc import IntegrityError @@ -125,3 +126,25 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_optional_annotated_decimal(): + class Model(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + dec: Annotated[Decimal, Field(max_digits=4, decimal_places=2)] | None = None + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(model := Model(dec=Decimal("3.14"))) + session.commit() + session.refresh(model) + assert model.dec == Decimal("3.14") + + with Session(engine) as session: + session.add(model := Model(dec=Decimal("3.142"))) + session.commit() + session.refresh(model) + assert model.dec == Decimal("3.14") From 12368da428b0ef5153935fa91d4bd63a1f8e2414 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:34:24 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 19af8fb2cd..66d9cf6823 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,7 +52,13 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, _AnnotatedAlias, deprecated, get_origin +from typing_extensions import ( + Literal, + TypeAlias, + _AnnotatedAlias, + deprecated, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2,