Skip to content

Commit 015601c

Browse files
author
John Lyu
committed
improve code structure
1 parent a3044bb commit 015601c

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

sqlmodel/_compat.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
6666
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
6767

6868

69+
def set_polymorphic_default_value(self_instance, values):
70+
"""By defalut, when init a model, pydantic will set the polymorphic_on
71+
value to field default value. But when inherit a model, the polymorphic_on
72+
should be set to polymorphic_identity value by default."""
73+
cls = type(self_instance)
74+
mapper = inspect(cls)
75+
if isinstance(mapper, Mapper):
76+
polymorphic_on = mapper.polymorphic_on
77+
if polymorphic_on is not None:
78+
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
79+
field_info = get_model_fields(cls).get(polymorphic_property.key)
80+
if field_info:
81+
v = values.get(polymorphic_property.key)
82+
# if model is inherited or polymorphic_on is not explicitly set
83+
# set the polymorphic_on by default
84+
if mapper.inherits or v is None:
85+
setattr(
86+
self_instance,
87+
polymorphic_property.key,
88+
mapper.polymorphic_identity,
89+
)
90+
91+
6992
@contextmanager
7093
def partial_init() -> Generator[None, None, None]:
7194
token = finish_init.set(False)
@@ -293,22 +316,7 @@ def sqlmodel_table_construct(
293316
setattr(self_instance, key, value)
294317
# End SQLModel override
295318
# Override polymorphic_on default value
296-
mapper = inspect(cls)
297-
if isinstance(mapper, Mapper):
298-
polymorphic_on = mapper.polymorphic_on
299-
if polymorphic_on is not None:
300-
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
301-
field_info = cls.model_fields.get(polymorphic_property.key)
302-
if field_info:
303-
v = values.get(polymorphic_property.key)
304-
# if model is inherited or polymorphic_on is not explicitly set
305-
# set the polymorphic_on by default
306-
if mapper.inherits or v is None:
307-
setattr(
308-
self_instance,
309-
polymorphic_property.key,
310-
mapper.polymorphic_identity,
311-
)
319+
set_polymorphic_default_value(self_instance, values)
312320
return self_instance
313321

314322
def sqlmodel_validate(
@@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
592600
for key in non_pydantic_keys:
593601
if key in self.__sqlmodel_relationships__:
594602
setattr(self, key, data[key])
603+
# Override polymorphic_on default value
604+
set_polymorphic_default_value(self, values)

tests/test_polymorphic_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class DarkHero(Hero):
5151

5252

5353
@needs_pydanticv2
54-
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
54+
def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None:
5555
class Hero(SQLModel, table=True):
5656
__tablename__ = "hero"
5757
id: Optional[int] = Field(default=None, primary_key=True)
@@ -123,7 +123,7 @@ class DarkHero(Hero):
123123
with Session(engine) as db:
124124
hero = Hero()
125125
db.add(hero)
126-
dark_hero = DarkHero()
126+
dark_hero = DarkHero(dark_power="pokey")
127127
db.add(dark_hero)
128128
db.commit()
129129
statement = select(DarkHero)

0 commit comments

Comments
 (0)