diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..60cb6f4000 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -10,6 +10,7 @@ Dict, ForwardRef, Generator, + Literal, Mapping, Optional, Set, @@ -23,6 +24,8 @@ from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin +from .sql.sqltypes import AutoString + # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2]) @@ -459,7 +462,9 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none # type: ignore[no-any-return, attr-defined] def get_sa_type_from_field(field: Any) -> Any: - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + if get_origin(field.type_) is Literal: + return AutoString + elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..404d1efd0d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -655,6 +655,9 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) + # Checks for `Literal` type annotation + if type_ is Literal: + return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): return sa_Enum(type_) diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..5416bfc666 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Literal, Optional import pytest from sqlalchemy.exc import IntegrityError @@ -125,3 +125,25 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_literal_typehints_are_treated_as_strings(clear_sqlmodel): + """Test https://github.com/fastapi/sqlmodel/issues/57""" + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(unique=True) + weakness: Literal["Kryptonite", "Dehydration", "Munchies"] + + superguy = Hero(name="Superguy", weakness="Kryptonite") + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(superguy) + session.commit() + session.refresh(superguy) + assert superguy.weakness == "Kryptonite" + assert isinstance(superguy.weakness, str)