Skip to content
5 changes: 3 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
Mapped,
MappedColumn,
RelationshipProperty,
declared_attr,
registry,
Expand Down Expand Up @@ -701,13 +702,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
if IS_PYDANTIC_V2:
field_info = field
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, (Column, MappedColumn)):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
Expand Down
51 changes: 31 additions & 20 deletions tests/test_field_sa_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,85 @@

import pytest
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import mapped_column
from sqlmodel import Field, SQLModel


def test_sa_column_takes_precedence() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_takes_precedence(clear_sqlmodel, column_class) -> None:
class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column=Column(String, primary_key=True, nullable=False),
sa_column=column_class(String, primary_key=True, nullable=False),
)

# It would have been nullable with no sa_column
assert Item.id.nullable is False # type: ignore
assert isinstance(Item.id.type, String) # type: ignore


def test_sa_column_no_sa_args() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_sa_args(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column_args=[Integer],
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_sa_kargs() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_sa_kargs(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column_kwargs={"primary_key": True},
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_type() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_type(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_type=Integer,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_primary_key() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_primary_key(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
primary_key=True,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_nullable() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_nullable(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
nullable=True,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_foreign_key() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_foreign_key(clear_sqlmodel, column_class) -> None:
with pytest.raises(RuntimeError):

class Team(SQLModel, table=True):
Expand All @@ -84,38 +92,41 @@ class Hero(SQLModel, table=True):
team_id: Optional[int] = Field(
default=None,
foreign_key="team.id",
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_unique() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_unique(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
unique=True,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_index() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_index(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
index=True,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
)


def test_sa_column_no_ondelete() -> None:
@pytest.mark.parametrize("column_class", [Column, mapped_column])
def test_sa_column_no_ondelete(column_class) -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_column=Column(Integer, primary_key=True),
sa_column=column_class(Integer, primary_key=True),
ondelete="CASCADE",
)
Loading