diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..66d9cf6823 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,7 +52,13 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import ( + Literal, + TypeAlias, + _AnnotatedAlias, + deprecated, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -653,13 +659,24 @@ def get_sqlalchemy_type(field: Any) -> Any: return sa_type type_ = get_sa_type_from_field(field) - metadata = get_field_metadata(field) + if isinstance(type_, _AnnotatedAlias): + class_to_compare = type_.__origin__ + if len(type_.__metadata__) == 1: + metadata = get_field_metadata(type_.__metadata__[0]) + else: + # not sure if this is the right behavior + raise ValueError( + f"AnnotatedAlias with multiple metadata is not supported: {type_}" + ) + else: + class_to_compare = type_ + metadata = get_field_metadata(field) # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(type_, Enum): + if issubclass(class_to_compare, Enum): return sa_Enum(type_) if issubclass( - type_, + class_to_compare, ( str, ipaddress.IPv4Address, @@ -674,28 +691,28 @@ def get_sqlalchemy_type(field: Any) -> Any: if max_length: return AutoString(length=max_length) return AutoString - if issubclass(type_, float): + if issubclass(class_to_compare, float): return Float - if issubclass(type_, bool): + if issubclass(class_to_compare, bool): return Boolean - if issubclass(type_, int): + if issubclass(class_to_compare, int): return Integer - if issubclass(type_, datetime): + if issubclass(class_to_compare, datetime): return DateTime - if issubclass(type_, date): + if issubclass(class_to_compare, date): return Date - if issubclass(type_, timedelta): + if issubclass(class_to_compare, timedelta): return Interval - if issubclass(type_, time): + if issubclass(class_to_compare, time): return Time - if issubclass(type_, bytes): + if issubclass(class_to_compare, bytes): return LargeBinary - if issubclass(type_, Decimal): + if issubclass(class_to_compare, Decimal): return Numeric( precision=getattr(metadata, "max_digits", None), scale=getattr(metadata, "decimal_places", None), ) - if issubclass(type_, uuid.UUID): + if issubclass(class_to_compare, uuid.UUID): return Uuid raise ValueError(f"{type_} has no matching SQLAlchemy type") diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..3b99833c9c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from decimal import Decimal +from typing import Annotated, List, Optional import pytest from sqlalchemy.exc import IntegrityError @@ -125,3 +126,25 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_optional_annotated_decimal(): + class Model(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + dec: Annotated[Decimal, Field(max_digits=4, decimal_places=2)] | None = None + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(model := Model(dec=Decimal("3.14"))) + session.commit() + session.refresh(model) + assert model.dec == Decimal("3.14") + + with Session(engine) as session: + session.add(model := Model(dec=Decimal("3.142"))) + session.commit() + session.refresh(model) + assert model.dec == Decimal("3.14")