Skip to content

Commit 6d93a46

Browse files
author
John Lyu
committed
support sqlalchemy polymorphic
1 parent e86b5fc commit 6d93a46

File tree

3 files changed

+180
-8
lines changed

3 files changed

+180
-8
lines changed

sqlmodel/_compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import VERSION as P_VERSION
2222
from pydantic import BaseModel
2323
from pydantic.fields import FieldInfo
24+
from sqlalchemy import inspect
2425
from typing_extensions import Annotated, get_args, get_origin
2526

2627
# Reassign variable to make it reexported for mypy
@@ -290,6 +291,19 @@ def sqlmodel_table_construct(
290291
if value is not Undefined:
291292
setattr(self_instance, key, value)
292293
# End SQLModel override
294+
# Override polymorphic_on default value
295+
mapper = inspect(cls)
296+
polymorphic_on = mapper.polymorphic_on
297+
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
298+
field_info = cls.model_fields.get(polymorphic_property.key)
299+
if field_info:
300+
v = values.get(polymorphic_property.key)
301+
# if model is inherited or polymorphic_on is not explicitly set
302+
# set the polymorphic_on by default
303+
if mapper.inherits or v is None:
304+
setattr(
305+
self_instance, polymorphic_property.key, mapper.polymorphic_identity
306+
)
293307
return self_instance
294308

295309
def sqlmodel_validate(

sqlmodel/main.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141
)
4242
from sqlalchemy import Enum as sa_Enum
4343
from sqlalchemy.orm import (
44+
InstrumentedAttribute,
4445
Mapped,
46+
MappedColumn,
4547
RelationshipProperty,
46-
declared_attr,
4748
registry,
4849
relationship,
4950
)
@@ -544,6 +545,15 @@ def __new__(
544545
**pydantic_annotations,
545546
**new_cls.__annotations__,
546547
}
548+
# pydantic will set class attribute value inherited from parent as field
549+
# default value, reset it back
550+
base_fields = {}
551+
for base in bases[::-1]:
552+
if issubclass(base, BaseModel):
553+
base_fields.update(base.model_fields)
554+
for k, v in new_cls.model_fields.items():
555+
if isinstance(v.default, InstrumentedAttribute):
556+
new_cls.model_fields[k] = base_fields.get(k)
547557

548558
def get_config(name: str) -> Any:
549559
config_class_value = get_config_value(
@@ -558,9 +568,19 @@ def get_config(name: str) -> Any:
558568

559569
config_table = get_config("table")
560570
if config_table is True:
571+
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
572+
new_cls.__tablename__ = new_cls.__name__.lower()
561573
# If it was passed by kwargs, ensure it's also set in config
562574
set_config_value(model=new_cls, parameter="table", value=config_table)
563575
for k, v in get_model_fields(new_cls).items():
576+
original_v = getattr(new_cls, k, None)
577+
if (
578+
isinstance(original_v, InstrumentedAttribute)
579+
and k not in class_dict
580+
):
581+
# The attribute was already set by SQLAlchemy, don't override it
582+
# Needed for polymorphic models, see #36
583+
continue
564584
col = get_column_from_field(v)
565585
setattr(new_cls, k, col)
566586
# Set a config flag to tell FastAPI that this should be read with a field
@@ -594,7 +614,13 @@ def __init__(
594614
# trying to create a new SQLAlchemy, for a new table, with the same name, that
595615
# triggers an error
596616
base_is_table = any(is_table_model_class(base) for base in bases)
597-
if is_table_model_class(cls) and not base_is_table:
617+
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
618+
"polymorphic_identity"
619+
)
620+
has_polymorphic = polymorphic_identity is not None
621+
622+
# allow polymorphic models inherit from table models
623+
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
598624
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
599625
if rel_info.sa_relationship:
600626
# There's a SQLAlchemy relationship declared, that takes precedence
@@ -641,6 +667,16 @@ def __init__(
641667
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
642668
# Tag: 1.4.36
643669
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
670+
# # patch sqlmodel field's default value to polymorphic_identity
671+
# if has_polymorphic:
672+
# mapper = inspect(cls)
673+
# polymorphic_on = mapper.polymorphic_on
674+
# polymorphic_property = mapper.get_property_by_column(polymorphic_on)
675+
# field = cls.model_fields.get(polymorphic_property.key)
676+
# def get__polymorphic_identity__(kw):
677+
# return polymorphic_identity
678+
# if field:
679+
# field.default_factory = get__polymorphic_identity__
644680
else:
645681
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
646682

@@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
708744
else:
709745
field_info = field.field_info
710746
sa_column = getattr(field_info, "sa_column", Undefined)
711-
if isinstance(sa_column, Column):
747+
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
712748
return sa_column
713749
sa_type = get_sqlalchemy_type(field)
714750
primary_key = getattr(field_info, "primary_key", Undefined)
@@ -772,7 +808,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
772808
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
773809
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
774810
__slots__ = ("__weakref__",)
775-
__tablename__: ClassVar[Union[str, Callable[..., str]]]
776811
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
777812
__name__: ClassVar[str]
778813
metadata: ClassVar[MetaData]
@@ -836,10 +871,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
836871
if not (isinstance(k, str) and k.startswith("_sa_"))
837872
]
838873

839-
@declared_attr # type: ignore
840-
def __tablename__(cls) -> str:
841-
return cls.__name__.lower()
842-
843874
@classmethod
844875
def model_validate(
845876
cls: Type[_TSQLModel],

tests/test_polymorphic_model.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from typing import Optional
2+
3+
from sqlalchemy import ForeignKey
4+
from sqlalchemy.orm import mapped_column
5+
from sqlmodel import Field, Session, SQLModel, create_engine, select
6+
7+
8+
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
9+
class Hero(SQLModel, table=True):
10+
__tablename__ = "hero"
11+
id: Optional[int] = Field(default=None, primary_key=True)
12+
hero_type: str = Field(default="hero")
13+
14+
__mapper_args__ = {
15+
"polymorphic_on": "hero_type",
16+
"polymorphic_identity": "hero",
17+
}
18+
19+
class DarkHero(Hero):
20+
__tablename__ = "dark_hero"
21+
id: Optional[int] = Field(
22+
default=None,
23+
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
24+
)
25+
dark_power: str = Field(
26+
default="dark",
27+
sa_column=mapped_column(
28+
nullable=False, use_existing_column=True, default="dark"
29+
),
30+
)
31+
32+
__mapper_args__ = {
33+
"polymorphic_identity": "dark",
34+
}
35+
36+
engine = create_engine("sqlite:///:memory:", echo=True)
37+
SQLModel.metadata.create_all(engine)
38+
with Session(engine) as db:
39+
hero = Hero()
40+
db.add(hero)
41+
dark_hero = DarkHero()
42+
db.add(dark_hero)
43+
db.commit()
44+
statement = select(DarkHero)
45+
result = db.exec(statement).all()
46+
assert len(result) == 1
47+
assert isinstance(result[0].dark_power, str)
48+
49+
50+
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
51+
class Hero(SQLModel, table=True):
52+
__tablename__ = "hero"
53+
id: Optional[int] = Field(default=None, primary_key=True)
54+
hero_type: str = Field(default="hero")
55+
56+
__mapper_args__ = {
57+
"polymorphic_on": "hero_type",
58+
"polymorphic_identity": "hero",
59+
}
60+
61+
class DarkHero(Hero):
62+
__tablename__ = "dark_hero"
63+
id: Optional[int] = Field(
64+
default=None,
65+
primary_key=True,
66+
foreign_key="hero.id",
67+
)
68+
dark_power: str = Field(
69+
default="dark",
70+
sa_column=mapped_column(
71+
nullable=False, use_existing_column=True, default="dark"
72+
),
73+
)
74+
75+
__mapper_args__ = {
76+
"polymorphic_identity": "dark",
77+
}
78+
79+
engine = create_engine("sqlite:///:memory:", echo=True)
80+
SQLModel.metadata.create_all(engine)
81+
with Session(engine) as db:
82+
hero = Hero()
83+
db.add(hero)
84+
dark_hero = DarkHero()
85+
db.add(dark_hero)
86+
db.commit()
87+
statement = select(DarkHero)
88+
result = db.exec(statement).all()
89+
assert len(result) == 1
90+
assert isinstance(result[0].dark_power, str)
91+
92+
93+
def test_polymorphic_single_table(clear_sqlmodel) -> None:
94+
class Hero(SQLModel, table=True):
95+
__tablename__ = "hero"
96+
id: Optional[int] = Field(default=None, primary_key=True)
97+
hero_type: str = Field(default="hero")
98+
99+
__mapper_args__ = {
100+
"polymorphic_on": "hero_type",
101+
"polymorphic_identity": "hero",
102+
}
103+
104+
class DarkHero(Hero):
105+
dark_power: str = Field(
106+
default="dark",
107+
sa_column=mapped_column(
108+
nullable=False, use_existing_column=True, default="dark"
109+
),
110+
)
111+
112+
__mapper_args__ = {
113+
"polymorphic_identity": "dark",
114+
}
115+
116+
engine = create_engine("sqlite:///:memory:", echo=True)
117+
SQLModel.metadata.create_all(engine)
118+
with Session(engine) as db:
119+
hero = Hero()
120+
db.add(hero)
121+
dark_hero = DarkHero()
122+
db.add(dark_hero)
123+
db.commit()
124+
statement = select(DarkHero)
125+
result = db.exec(statement).all()
126+
assert len(result) == 1
127+
assert isinstance(result[0].dark_power, str)

0 commit comments

Comments
 (0)