Skip to content

Commit 7d333fe

Browse files
author
John Lyu
committed
fix relationship problem of parent class during polymorphic inherit
1 parent c1dff79 commit 7d333fe

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

sqlmodel/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,14 @@ def __new__(
548548
if issubclass(base, BaseModel):
549549
base_fields.update(get_model_fields(base))
550550
base_annotations.update(base.__annotations__)
551+
if hasattr(base, "__sqlmodel_relationships__"):
552+
for k in base.__sqlmodel_relationships__:
553+
# create a dummy attribute to avoid inherit
554+
# pydantic will treat it as class variables, and will not become fields on model instances
555+
anno = base_annotations.get(k, Any)
556+
dummy_anno = ClassVar[anno]
557+
dict_used["__annotations__"][k] = dummy_anno
558+
551559
if hasattr(base, "__tablename__"):
552560
is_polymorphic = True
553561
# use base_fields overwriting the ones from the class for inherit

tests/test_polymorphic_model.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sqlalchemy import ForeignKey
44
from sqlalchemy.orm import mapped_column
5-
from sqlmodel import Field, Session, SQLModel, create_engine, select
5+
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
66

77
from tests.conftest import needs_pydanticv2
88

@@ -16,7 +16,7 @@ class Hero(SQLModel, table=True):
1616

1717
__mapper_args__ = {
1818
"polymorphic_on": "hero_type",
19-
"polymorphic_identity": "hero",
19+
"polymorphic_identity": "normal_hero",
2020
}
2121

2222
class DarkHero(Hero):
@@ -59,7 +59,7 @@ class Hero(SQLModel, table=True):
5959

6060
__mapper_args__ = {
6161
"polymorphic_on": "hero_type",
62-
"polymorphic_identity": "hero",
62+
"polymorphic_identity": "normal_hero",
6363
}
6464

6565
class DarkHero(Hero):
@@ -103,7 +103,7 @@ class Hero(SQLModel, table=True):
103103

104104
__mapper_args__ = {
105105
"polymorphic_on": "hero_type",
106-
"polymorphic_identity": "hero",
106+
"polymorphic_identity": "normal_hero",
107107
}
108108

109109
class DarkHero(Hero):
@@ -130,3 +130,48 @@ class DarkHero(Hero):
130130
result = db.exec(statement).all()
131131
assert len(result) == 1
132132
assert isinstance(result[0].dark_power, str)
133+
134+
135+
@needs_pydanticv2
136+
def test_polymorphic_relationship(clear_sqlmodel) -> None:
137+
class Tool(SQLModel, table=True):
138+
__tablename__ = "tool_table"
139+
140+
id: int = Field(primary_key=True)
141+
142+
name: str
143+
144+
class Person(SQLModel, table=True):
145+
__tablename__ = "person_table"
146+
147+
id: int = Field(primary_key=True)
148+
149+
discriminator: str
150+
name: str
151+
152+
tool_id: int = Field(foreign_key="tool_table.id")
153+
tool: Tool = Relationship()
154+
155+
__mapper_args__ = {
156+
"polymorphic_on": "discriminator",
157+
"polymorphic_identity": "simple_person",
158+
}
159+
160+
class Worker(Person):
161+
__mapper_args__ = {
162+
"polymorphic_identity": "worker",
163+
}
164+
165+
engine = create_engine("sqlite:///:memory:", echo=True)
166+
SQLModel.metadata.create_all(engine)
167+
with Session(engine) as db:
168+
tool = Tool(id=1, name="Hammer")
169+
db.add(tool)
170+
worker = Worker(id=2, name="Bob", tool_id=1)
171+
db.add(worker)
172+
db.commit()
173+
174+
statement = select(Worker).where(Worker.tool_id == 1)
175+
result = db.exec(statement).all()
176+
assert len(result) == 1
177+
assert isinstance(result[0].tool, Tool)

0 commit comments

Comments
 (0)