Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, Literal, get_args, get_origin

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
Expand Down Expand Up @@ -459,6 +459,8 @@ def is_field_noneable(field: "FieldInfo") -> bool:
return field.allow_none # type: ignore[no-any-return, attr-defined]

def get_sa_type_from_field(field: Any) -> Any:
if get_origin(field.type_) is Literal:
return Literal
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
return field.type_
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
Expand Down
3 changes: 3 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,9 @@ def get_sqlalchemy_type(field: Any) -> Any:
type_ = get_sa_type_from_field(field)
metadata = get_field_metadata(field)

# Checks for `Literal` type annotation
if type_ is Literal:
return AutoString
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if issubclass(type_, Enum):
return sa_Enum(type_)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import RelationshipProperty
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
from typing_extensions import Literal


def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
Expand Down Expand Up @@ -125,3 +126,25 @@ class Hero(SQLModel, table=True):
# The next statement should not raise an AttributeError
assert hero_rusty_man.team
assert hero_rusty_man.team.name == "Preventers"


def test_literal_typehints_are_treated_as_strings(clear_sqlmodel):
"""Test https://github.com/fastapi/sqlmodel/issues/57"""

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(unique=True)
weakness: Literal["Kryptonite", "Dehydration", "Munchies"]

superguy = Hero(name="Superguy", weakness="Kryptonite")

engine = create_engine("sqlite://", echo=True)

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(superguy)
session.commit()
session.refresh(superguy)
assert superguy.weakness == "Kryptonite"
assert isinstance(superguy.weakness, str)
Loading