Skip to content

🐛 Fix TypeError for fields annotated with Literal #1439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
7 changes: 6 additions & 1 deletion sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dict,
ForwardRef,
Generator,
Literal,
Mapping,
Optional,
Set,
Expand All @@ -23,6 +24,8 @@
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, get_args, get_origin

from .sql.sqltypes import AutoString

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2])
Expand Down Expand Up @@ -459,7 +462,9 @@ def is_field_noneable(field: "FieldInfo") -> bool:
return field.allow_none # type: ignore[no-any-return, attr-defined]

def get_sa_type_from_field(field: Any) -> Any:
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
if get_origin(field.type_) is Literal:
return AutoString
elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
return field.type_
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")

Expand Down
3 changes: 3 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,9 @@ def get_sqlalchemy_type(field: Any) -> Any:
type_ = get_sa_type_from_field(field)
metadata = get_field_metadata(field)

# Checks for `Literal` type annotation
if type_ is Literal:
return AutoString
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if issubclass(type_, Enum):
return sa_Enum(type_)
Expand Down
24 changes: 23 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Literal, Optional

import pytest
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -125,3 +125,25 @@ class Hero(SQLModel, table=True):
# The next statement should not raise an AttributeError
assert hero_rusty_man.team
assert hero_rusty_man.team.name == "Preventers"


def test_literal_typehints_are_treated_as_strings(clear_sqlmodel):
"""Test https://github.com/fastapi/sqlmodel/issues/57"""

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(unique=True)
weakness: Literal["Kryptonite", "Dehydration", "Munchies"]

superguy = Hero(name="Superguy", weakness="Kryptonite")

engine = create_engine("sqlite://", echo=True)

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(superguy)
session.commit()
session.refresh(superguy)
assert superguy.weakness == "Kryptonite"
assert isinstance(superguy.weakness, str)
Loading