@@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
6666finish_init : ContextVar [bool ] = ContextVar ("finish_init" , default = True )
6767
6868
69+ def set_polymorphic_default_value (self_instance , values ):
70+ """By defalut, when init a model, pydantic will set the polymorphic_on
71+ value to field default value. But when inherit a model, the polymorphic_on
72+ should be set to polymorphic_identity value by default."""
73+ cls = type (self_instance )
74+ mapper = inspect (cls )
75+ if isinstance (mapper , Mapper ):
76+ polymorphic_on = mapper .polymorphic_on
77+ if polymorphic_on is not None :
78+ polymorphic_property = mapper .get_property_by_column (polymorphic_on )
79+ field_info = get_model_fields (cls ).get (polymorphic_property .key )
80+ if field_info :
81+ v = values .get (polymorphic_property .key )
82+ # if model is inherited or polymorphic_on is not explicitly set
83+ # set the polymorphic_on by default
84+ if mapper .inherits or v is None :
85+ setattr (
86+ self_instance ,
87+ polymorphic_property .key ,
88+ mapper .polymorphic_identity ,
89+ )
90+
91+
6992@contextmanager
7093def partial_init () -> Generator [None , None , None ]:
7194 token = finish_init .set (False )
@@ -293,22 +316,7 @@ def sqlmodel_table_construct(
293316 setattr (self_instance , key , value )
294317 # End SQLModel override
295318 # Override polymorphic_on default value
296- mapper = inspect (cls )
297- if isinstance (mapper , Mapper ):
298- polymorphic_on = mapper .polymorphic_on
299- if polymorphic_on is not None :
300- polymorphic_property = mapper .get_property_by_column (polymorphic_on )
301- field_info = cls .model_fields .get (polymorphic_property .key )
302- if field_info :
303- v = values .get (polymorphic_property .key )
304- # if model is inherited or polymorphic_on is not explicitly set
305- # set the polymorphic_on by default
306- if mapper .inherits or v is None :
307- setattr (
308- self_instance ,
309- polymorphic_property .key ,
310- mapper .polymorphic_identity ,
311- )
319+ set_polymorphic_default_value (self_instance , values )
312320 return self_instance
313321
314322 def sqlmodel_validate (
@@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
592600 for key in non_pydantic_keys :
593601 if key in self .__sqlmodel_relationships__ :
594602 setattr (self , key , data [key ])
603+ # Override polymorphic_on default value
604+ set_polymorphic_default_value (self , values )
0 commit comments